diff --git a/bevy_renet/examples/simple.rs b/bevy_renet/examples/simple.rs index 5244b526..8c169b10 100644 --- a/bevy_renet/examples/simple.rs +++ b/bevy_renet/examples/simple.rs @@ -9,7 +9,7 @@ use bevy_renet::{ RenetClientPlugin, RenetServerPlugin, }; use renet::{ - transport::{NetcodeClientTransport, NetcodeServerTransport, NetcodeTransportError}, + transport::{NativeSocket, NetcodeClientTransport, NetcodeServerTransport, NetcodeTransportError}, ClientId, }; @@ -58,7 +58,7 @@ fn new_renet_client() -> (RenetClient, NetcodeClientTransport) { user_data: None, }; - let transport = NetcodeClientTransport::new(current_time, authentication, socket).unwrap(); + let transport = NetcodeClientTransport::new(current_time, authentication, NativeSocket::new(socket).unwrap()).unwrap(); let client = RenetClient::new(ConnectionConfig::default()); (client, transport) @@ -76,7 +76,7 @@ fn new_renet_server() -> (RenetServer, NetcodeServerTransport) { authentication: ServerAuthentication::Unsecure, }; - let transport = NetcodeServerTransport::new(server_config, socket).unwrap(); + let transport = NetcodeServerTransport::new(server_config, NativeSocket::new(socket).unwrap()).unwrap(); let server = RenetServer::new(ConnectionConfig::default()); (server, transport) diff --git a/demo_bevy/src/bin/client.rs b/demo_bevy/src/bin/client.rs index c742929e..0d45773e 100644 --- a/demo_bevy/src/bin/client.rs +++ b/demo_bevy/src/bin/client.rs @@ -8,7 +8,7 @@ use bevy::{ use bevy_egui::{EguiContexts, EguiPlugin}; use bevy_renet::{ client_connected, - renet::{ClientId, RenetClient}, + renet::{transport::NativeSocket, ClientId, RenetClient}, RenetClientPlugin, }; use demo_bevy::{ @@ -63,7 +63,7 @@ fn add_netcode_network(app: &mut App) { user_data: None, }; - let transport = NetcodeClientTransport::new(current_time, authentication, socket).unwrap(); + let transport = NetcodeClientTransport::new(current_time, authentication, NativeSocket::new(socket).unwrap()).unwrap(); app.insert_resource(client); app.insert_resource(transport); diff --git a/demo_bevy/src/bin/server.rs b/demo_bevy/src/bin/server.rs index c59ed34f..3c7e5551 100644 --- a/demo_bevy/src/bin/server.rs +++ b/demo_bevy/src/bin/server.rs @@ -6,7 +6,7 @@ use bevy::{ }; use bevy_egui::{EguiContexts, EguiPlugin}; use bevy_renet::{ - renet::{ClientId, RenetServer, ServerEvent}, + renet::{transport::NativeSocket, ClientId, RenetServer, ServerEvent}, RenetServerPlugin, }; use demo_bevy::{ @@ -52,7 +52,7 @@ fn add_netcode_network(app: &mut App) { authentication: ServerAuthentication::Unsecure, }; - let transport = NetcodeServerTransport::new(server_config, socket).unwrap(); + let transport = NetcodeServerTransport::new(server_config, NativeSocket::new(socket).unwrap()).unwrap(); app.insert_resource(server); app.insert_resource(transport); } diff --git a/demo_chat/src/server.rs b/demo_chat/src/server.rs index 07f9a967..463d759e 100644 --- a/demo_chat/src/server.rs +++ b/demo_chat/src/server.rs @@ -6,7 +6,7 @@ use std::{ }; use renet::{ - transport::{NetcodeServerTransport, ServerAuthentication, ServerConfig}, + transport::{NativeSocket, NetcodeServerTransport, ServerAuthentication, ServerConfig}, ClientId, ConnectionConfig, DefaultChannel, RenetServer, ServerEvent, }; use renet_visualizer::RenetServerVisualizer; @@ -38,7 +38,7 @@ impl ChatServer { authentication: ServerAuthentication::Unsecure, }; - let transport = NetcodeServerTransport::new(server_config, socket).unwrap(); + let transport = NetcodeServerTransport::new(server_config, NativeSocket::new(socket).unwrap()).unwrap(); let server: RenetServer = RenetServer::new(ConnectionConfig::default()); diff --git a/demo_chat/src/ui.rs b/demo_chat/src/ui.rs index b19572a3..fbeda1af 100644 --- a/demo_chat/src/ui.rs +++ b/demo_chat/src/ui.rs @@ -4,7 +4,7 @@ use eframe::{ epaint::PathShape, }; use renet::{ - transport::{ClientAuthentication, NetcodeClientTransport}, + transport::{ClientAuthentication, NativeSocket, NetcodeClientTransport}, ClientId, ConnectionConfig, DefaultChannel, RenetClient, }; @@ -267,7 +267,7 @@ fn create_renet_client(username: String, server_addr: SocketAddr) -> (RenetClien protocol_id: PROTOCOL_ID, }; - let transport = NetcodeClientTransport::new(current_time, authentication, socket).unwrap(); + let transport = NetcodeClientTransport::new(current_time, authentication, NativeSocket::new(socket).unwrap()).unwrap(); (client, transport) } diff --git a/renet/examples/echo.rs b/renet/examples/echo.rs index 2cb043e0..6e2f475e 100644 --- a/renet/examples/echo.rs +++ b/renet/examples/echo.rs @@ -8,7 +8,8 @@ use std::{ use renet::{ transport::{ - ClientAuthentication, NetcodeClientTransport, NetcodeServerTransport, ServerAuthentication, ServerConfig, NETCODE_USER_DATA_BYTES, + ClientAuthentication, NativeSocket, NetcodeClientTransport, NetcodeServerTransport, ServerAuthentication, ServerConfig, + NETCODE_USER_DATA_BYTES, }, ClientId, ConnectionConfig, DefaultChannel, RenetClient, RenetServer, ServerEvent, }; @@ -77,7 +78,7 @@ fn server(public_addr: SocketAddr) { }; let socket: UdpSocket = UdpSocket::bind(public_addr).unwrap(); - let mut transport = NetcodeServerTransport::new(server_config, socket).unwrap(); + let mut transport = NetcodeServerTransport::new(server_config, NativeSocket::new(socket).unwrap()).unwrap(); let mut usernames: HashMap = HashMap::new(); let mut received_messages = vec![]; @@ -152,7 +153,7 @@ fn client(server_addr: SocketAddr, username: Username) { protocol_id: PROTOCOL_ID, }; - let mut transport = NetcodeClientTransport::new(current_time, authentication, socket).unwrap(); + let mut transport = NetcodeClientTransport::new(current_time, authentication, NativeSocket::new(socket).unwrap()).unwrap(); let stdin_channel: Receiver = spawn_stdin_channel(); let mut last_updated = Instant::now(); diff --git a/renet/src/transport/client.rs b/renet/src/transport/client.rs index feaac24b..e5d96753 100644 --- a/renet/src/transport/client.rs +++ b/renet/src/transport/client.rs @@ -1,37 +1,32 @@ -use std::{ - io, - net::{SocketAddr, UdpSocket}, - time::Duration, -}; +use std::{io, net::SocketAddr, time::Duration}; use renetcode::{ClientAuthentication, DisconnectReason, NetcodeClient, NetcodeError, NETCODE_MAX_PACKET_BYTES}; use crate::{remote_connection::RenetClient, ClientId}; -use super::NetcodeTransportError; +use super::{NetcodeTransportError, TransportSocket}; #[derive(Debug)] #[cfg_attr(feature = "bevy", derive(bevy_ecs::system::Resource))] pub struct NetcodeClientTransport { - socket: UdpSocket, + socket: Box, netcode_client: NetcodeClient, buffer: [u8; NETCODE_MAX_PACKET_BYTES], } impl NetcodeClientTransport { - pub fn new(current_time: Duration, authentication: ClientAuthentication, socket: UdpSocket) -> Result { - socket.set_nonblocking(true)?; + pub fn new(current_time: Duration, authentication: ClientAuthentication, socket: impl TransportSocket) -> Result { let netcode_client = NetcodeClient::new(current_time, authentication)?; Ok(Self { - buffer: [0u8; NETCODE_MAX_PACKET_BYTES], - socket, + socket: Box::new(socket), netcode_client, + buffer: [0u8; NETCODE_MAX_PACKET_BYTES], }) } pub fn addr(&self) -> io::Result { - self.socket.local_addr() + self.socket.addr() } pub fn client_id(&self) -> ClientId { @@ -39,12 +34,14 @@ impl NetcodeClientTransport { } /// Returns the duration since the client last received a packet. - /// Usefull to detect timeouts. + /// + /// Useful to detect timeouts. pub fn time_since_last_received_packet(&self) -> Duration { self.netcode_client.time_since_last_received_packet() } /// Disconnect the client from the transport layer. + /// /// This sends the disconnect packet instantly, use this when closing/exiting games, /// should use [RenetClient::disconnect][crate::RenetClient::disconnect] otherwise. pub fn disconnect(&mut self) { @@ -54,7 +51,7 @@ impl NetcodeClientTransport { match self.netcode_client.disconnect() { Ok((addr, packet)) => { - if let Err(e) = self.socket.send_to(packet, addr) { + if let Err(e) = self.socket.send(addr, packet) { log::error!("Failed to send disconnect packet: {e}"); } } @@ -68,7 +65,8 @@ impl NetcodeClientTransport { } /// Send packets to the server. - /// Should be called every tick + /// + /// Should be called every tick. pub fn send_packets(&mut self, connection: &mut RenetClient) -> Result<(), NetcodeTransportError> { if let Some(reason) = self.netcode_client.disconnect_reason() { return Err(NetcodeError::Disconnected(reason).into()); @@ -77,7 +75,7 @@ impl NetcodeClientTransport { let packets = connection.get_packets_to_send(); for packet in packets { let (addr, payload) = self.netcode_client.generate_payload_packet(&packet)?; - self.socket.send_to(payload, addr)?; + self.socket.send(addr, payload)?; } Ok(()) @@ -88,13 +86,21 @@ impl NetcodeClientTransport { if let Some(reason) = self.netcode_client.disconnect_reason() { // Mark the client as disconnected if an error occured in the transport layer client.disconnect_due_to_transport(); + self.socket.close(); return Err(NetcodeError::Disconnected(reason).into()); } + if self.socket.is_closed() { + client.disconnect_due_to_transport(); + } + if let Some(error) = client.disconnect_reason() { let (addr, disconnect_packet) = self.netcode_client.disconnect()?; - self.socket.send_to(disconnect_packet, addr)?; + if !self.socket.is_closed() { + self.socket.send(addr, disconnect_packet)?; + self.socket.close(); + } return Err(error.into()); } @@ -104,8 +110,10 @@ impl NetcodeClientTransport { client.set_connecting(); } + self.socket.preupdate(); + loop { - let packet = match self.socket.recv_from(&mut self.buffer) { + let packet = match self.socket.try_recv(&mut self.buffer) { Ok((len, addr)) => { if addr != self.netcode_client.server_addr() { log::debug!("Discarded packet from unknown server {:?}", addr); @@ -125,9 +133,11 @@ impl NetcodeClientTransport { } if let Some((packet, addr)) = self.netcode_client.update(duration) { - self.socket.send_to(packet, addr)?; + self.socket.send(addr, packet)?; } + self.socket.postupdate(); + Ok(()) } } diff --git a/renet/src/transport/mod.rs b/renet/src/transport/mod.rs index 5434e026..fa6ed804 100644 --- a/renet/src/transport/mod.rs +++ b/renet/src/transport/mod.rs @@ -1,10 +1,14 @@ use std::{error::Error, fmt}; mod client; +mod native_socket; mod server; +mod transport_socket; pub use client::*; +pub use native_socket::*; pub use server::*; +pub use transport_socket::*; pub use renetcode::{ generate_random_bytes, ClientAuthentication, ConnectToken, DisconnectReason as NetcodeDisconnectReason, NetcodeError, diff --git a/renet/src/transport/native_socket.rs b/renet/src/transport/native_socket.rs new file mode 100644 index 00000000..7daefea6 --- /dev/null +++ b/renet/src/transport/native_socket.rs @@ -0,0 +1,42 @@ +use std::net::{SocketAddr, UdpSocket}; + +use super::{NetcodeError, NetcodeTransportError, TransportSocket}; + +/// Implementation of [`TransportSocket`] for `UdpSockets`. +#[derive(Debug)] +pub struct NativeSocket { + socket: UdpSocket, +} + +impl NativeSocket { + /// Makes a new native socket. + pub fn new(socket: UdpSocket) -> Result { + socket.set_nonblocking(true)?; + Ok(Self { socket }) + } +} + +impl TransportSocket for NativeSocket { + fn addr(&self) -> std::io::Result { + self.socket.local_addr() + } + + fn is_closed(&mut self) -> bool { + false + } + + fn close(&mut self) {} + fn disconnect(&mut self, _: SocketAddr) {} + fn preupdate(&mut self) {} + + fn try_recv(&mut self, buffer: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> { + self.socket.recv_from(buffer) + } + + fn postupdate(&mut self) {} + + fn send(&mut self, addr: SocketAddr, packet: &[u8]) -> Result<(), NetcodeTransportError> { + self.socket.send_to(packet, addr)?; + Ok(()) + } +} diff --git a/renet/src/transport/server.rs b/renet/src/transport/server.rs index 3cf286df..57eef316 100644 --- a/renet/src/transport/server.rs +++ b/renet/src/transport/server.rs @@ -1,32 +1,26 @@ -use std::{ - io, - net::{SocketAddr, UdpSocket}, - time::Duration, -}; +use std::{io, net::SocketAddr, time::Duration}; use renetcode::{NetcodeServer, ServerConfig, ServerResult, NETCODE_MAX_PACKET_BYTES, NETCODE_USER_DATA_BYTES}; use crate::ClientId; use crate::RenetServer; -use super::NetcodeTransportError; +use super::{NetcodeTransportError, TransportSocket}; #[derive(Debug)] #[cfg_attr(feature = "bevy", derive(bevy_ecs::system::Resource))] pub struct NetcodeServerTransport { - socket: UdpSocket, + socket: Box, netcode_server: NetcodeServer, buffer: [u8; NETCODE_MAX_PACKET_BYTES], } impl NetcodeServerTransport { - pub fn new(server_config: ServerConfig, socket: UdpSocket) -> Result { - socket.set_nonblocking(true)?; - + pub fn new(server_config: ServerConfig, socket: impl TransportSocket) -> Result { let netcode_server = NetcodeServer::new(server_config); Ok(Self { - socket, + socket: Box::new(socket), netcode_server, buffer: [0; NETCODE_MAX_PACKET_BYTES], }) @@ -63,7 +57,7 @@ impl NetcodeServerTransport { pub fn disconnect_all(&mut self, server: &mut RenetServer) { for client_id in self.netcode_server.clients_id() { let server_result = self.netcode_server.disconnect(client_id); - handle_server_result(server_result, &self.socket, server); + handle_server_result(server_result, &mut self.socket, server); } } @@ -77,11 +71,13 @@ impl NetcodeServerTransport { pub fn update(&mut self, duration: Duration, server: &mut RenetServer) -> Result<(), NetcodeTransportError> { self.netcode_server.update(duration); + self.socket.preupdate(); + loop { - match self.socket.recv_from(&mut self.buffer) { + match self.socket.try_recv(&mut self.buffer) { Ok((len, addr)) => { let server_result = self.netcode_server.process_packet(addr, &mut self.buffer[..len]); - handle_server_result(server_result, &self.socket, server); + handle_server_result(server_result, &mut self.socket, server); } Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break, Err(ref e) if e.kind() == io::ErrorKind::Interrupted => break, @@ -92,14 +88,16 @@ impl NetcodeServerTransport { for client_id in self.netcode_server.clients_id() { let server_result = self.netcode_server.update_client(client_id); - handle_server_result(server_result, &self.socket, server); + handle_server_result(server_result, &mut self.socket, server); } for disconnection_id in server.disconnections_id() { let server_result = self.netcode_server.disconnect(disconnection_id.raw()); - handle_server_result(server_result, &self.socket, server); + handle_server_result(server_result, &mut self.socket, server); } + self.socket.postupdate(); + Ok(()) } @@ -110,7 +108,7 @@ impl NetcodeServerTransport { for packet in packets { match self.netcode_server.generate_payload_packet(client_id.raw(), &packet) { Ok((addr, payload)) => { - if let Err(e) = self.socket.send_to(payload, addr) { + if let Err(e) = self.socket.send(addr, payload) { log::error!("Failed to send packet to client {client_id} ({addr}): {e}"); continue 'clients; } @@ -125,9 +123,9 @@ impl NetcodeServerTransport { } } -fn handle_server_result(server_result: ServerResult, socket: &UdpSocket, reliable_server: &mut RenetServer) { - let send_packet = |packet: &[u8], addr: SocketAddr| { - if let Err(err) = socket.send_to(packet, addr) { +fn handle_server_result(server_result: ServerResult, socket: &mut Box, reliable_server: &mut RenetServer) { + let mut send_packet = |packet: &[u8], addr: SocketAddr| { + if let Err(err) = socket.send(addr, packet) { log::error!("Failed to send packet to {addr}: {err}"); } }; @@ -157,6 +155,7 @@ fn handle_server_result(server_result: ServerResult, socket: &UdpSocket, reliabl if let Some(payload) = payload { send_packet(payload, addr); } + socket.disconnect(addr); } } } diff --git a/renet/src/transport/transport_socket.rs b/renet/src/transport/transport_socket.rs new file mode 100644 index 00000000..983ed148 --- /dev/null +++ b/renet/src/transport/transport_socket.rs @@ -0,0 +1,44 @@ +use std::fmt::Debug; +use std::net::SocketAddr; + +use super::NetcodeTransportError; + +/// Unreliable data source for use in [`NetcodeServerTransport`] and [`NetcodeClientTransport`]. +/// +/// Note that while `netcode` uses `SocketAddr` everywhere, if your transport uses a different 'connection URL' +/// scheme then you can layer the bytes into the [`ConnectToken`](renet::ConnectToken) server address list. +/// Just keep in mind that when a client disconnects, the client will traverse the server address list to find +/// an address to reconnect with. If that isn't supported by your scheme, then when [`TransportSocket::send`] is +/// called with an invalid/unexpected server address you should return an error. If you want to support +/// multiple server addresses but your URLs exceed 16 bytes (IPV6 addresses are 16 bytes), then you should pre-parse +/// the server list from the connect token, and then map that list to the 16-byte IPV6 segments that will be produced +/// by the client when it tries to reconnect to different servers. +pub trait TransportSocket: Debug + Send + Sync + 'static { + /// Gets the data source's `SocketAddr`. + fn addr(&self) -> std::io::Result; + + /// Checks if the data source is closed. + fn is_closed(&mut self) -> bool; + + /// Closes the data source. + /// + /// This should disconnect any remote connections that are being tracked. + fn close(&mut self); + + /// Disconnects a remote connection with the given address. + fn disconnect(&mut self, addr: SocketAddr); + + /// Handles data-source-specific logic that must run before receiving packets. + fn preupdate(&mut self); + + /// Tries to receive the next packet sent to this data source. + /// + /// Returns the number of bytes written to the buffer, and the source address of the bytes. + fn try_recv(&mut self, buffer: &mut [u8]) -> std::io::Result<(usize, SocketAddr)>; + + /// Handles data-source-specific logic that must run after sending packets. + fn postupdate(&mut self); + + /// Sends a packet to the designated address. + fn send(&mut self, addr: SocketAddr, packet: &[u8]) -> Result<(), NetcodeTransportError>; +}