Skip to content

Commit

Permalink
Add interface binding on connection
Browse files Browse the repository at this point in the history
  • Loading branch information
sashacmc committed Feb 23, 2024
1 parent b233822 commit e96de43
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 26 deletions.
19 changes: 15 additions & 4 deletions commons/zenoh-util/src/std_only/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -510,23 +510,34 @@ fn set_bind_to_device(socket: std::os::raw::c_int, iface: &Option<String>) {
}

#[cfg(target_os = "linux")]
pub fn set_bind_to_device_tcp(socket: &TcpListener, iface: &Option<String>) {
pub fn set_bind_to_device_tcp_listener(socket: &TcpListener, iface: &Option<String>) {
use std::os::fd::AsRawFd;
set_bind_to_device(socket.as_raw_fd(), iface);
}

#[cfg(target_os = "linux")]
pub fn set_bind_to_device_udp(socket: &UdpSocket, iface: &Option<String>) {
pub fn set_bind_to_device_tcp_stream(socket: &TcpStream, iface: &Option<String>) {
use std::os::fd::AsRawFd;
set_bind_to_device(socket.as_raw_fd(), iface);
}

#[cfg(target_os = "linux")]
pub fn set_bind_to_device_udp_socket(socket: &UdpSocket, iface: &Option<String>) {
use std::os::fd::AsRawFd;
set_bind_to_device(socket.as_raw_fd(), iface);
}

#[cfg(any(target_os = "macos", target_os = "windows"))]
pub fn set_bind_to_device_tcp_listener(_socket: &TcpListener, _iface: &Option<String>) {
log::warn!("Listen at the interface is not supported for this platform");
}

#[cfg(any(target_os = "macos", target_os = "windows"))]
pub fn set_bind_to_device_tcp(_socket: &TcpListener, _iface: &Option<String>) {
pub fn set_bind_to_device_tcp_stream(_socket: &TcpStream, _iface: &Option<String>) {
log::warn!("Listen at the interface is not supported for this platform");
}

#[cfg(any(target_os = "macos", target_os = "windows"))]
pub fn set_bind_to_device_udp(_socket: &UdpSocket, _iface: &Option<String>) {
pub fn set_bind_to_device_udp_socket(_socket: &UdpSocket, _iface: &Option<String>) {
log::warn!("Listen at the interface is not supported for this platform");
}
8 changes: 6 additions & 2 deletions io/zenoh-links/zenoh-link-tcp/src/unicast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ impl LinkManagerUnicastTcp {
async fn new_link_inner(
&self,
dst_addr: &SocketAddr,
iface: &Option<String>,
) -> ZResult<(TcpStream, SocketAddr, SocketAddr)> {
let stream = TcpStream::connect(dst_addr)
.await
Expand All @@ -212,6 +213,8 @@ impl LinkManagerUnicastTcp {
.peer_addr()
.map_err(|e| zerror!("{}: {}", dst_addr, e))?;

zenoh_util::net::set_bind_to_device_tcp_stream(&stream, iface);

Ok((stream, src_addr, dst_addr))
}

Expand All @@ -225,7 +228,7 @@ impl LinkManagerUnicastTcp {
.await
.map_err(|e| zerror!("{}: {}", addr, e))?;

zenoh_util::net::set_bind_to_device_tcp(&socket, iface);
zenoh_util::net::set_bind_to_device_tcp_listener(&socket, iface);

let local_addr = socket
.local_addr()
Expand All @@ -239,10 +242,11 @@ impl LinkManagerUnicastTcp {
impl LinkManagerUnicastTrait for LinkManagerUnicastTcp {
async fn new_link(&self, endpoint: EndPoint) -> ZResult<LinkUnicast> {
let dst_addrs = get_tcp_addrs(endpoint.address()).await?;
let iface = endpoint.config().get(BIND_INTERFACE).map(|s| s.to_string());

let mut errs: Vec<ZError> = vec![];
for da in dst_addrs {
match self.new_link_inner(&da).await {
match self.new_link_inner(&da, &iface).await {
Ok((stream, src_addr, dst_addr)) => {
let link = Arc::new(LinkUnicastTcp::new(stream, src_addr, dst_addr));
return Ok(LinkUnicast(link));
Expand Down
8 changes: 6 additions & 2 deletions io/zenoh-links/zenoh-link-udp/src/unicast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ impl LinkManagerUnicastUdp {
async fn new_link_inner(
&self,
dst_addr: &SocketAddr,
iface: &Option<String>,
) -> ZResult<(UdpSocket, SocketAddr, SocketAddr)> {
// Establish a UDP socket
let socket = UdpSocket::bind(SocketAddr::new(
Expand All @@ -278,6 +279,8 @@ impl LinkManagerUnicastUdp {
e
})?;

zenoh_util::net::set_bind_to_device_udp_socket(&socket, iface);

// Connect the socket to the remote address
socket.connect(dst_addr).await.map_err(|e| {
let e = zerror!("Can not create a new UDP link bound to {}: {}", dst_addr, e);
Expand Down Expand Up @@ -313,7 +316,7 @@ impl LinkManagerUnicastUdp {
e
})?;

zenoh_util::net::set_bind_to_device_udp(&socket, iface);
zenoh_util::net::set_bind_to_device_udp_socket(&socket, iface);

let local_addr = socket.local_addr().map_err(|e| {
let e = zerror!("Can not create a new UDP listener on {}: {}", addr, e);
Expand All @@ -331,10 +334,11 @@ impl LinkManagerUnicastTrait for LinkManagerUnicastUdp {
let dst_addrs = get_udp_addrs(endpoint.address())
.await?
.filter(|a| !a.ip().is_multicast());
let iface = endpoint.config().get(BIND_INTERFACE).map(|s| s.to_string());

let mut errs: Vec<ZError> = vec![];
for da in dst_addrs {
match self.new_link_inner(&da).await {
match self.new_link_inner(&da, &iface).await {
Ok((socket, src_addr, dst_addr)) => {
// Create UDP link
let link = Arc::new(LinkUnicastUdp::new(
Expand Down
114 changes: 96 additions & 18 deletions io/zenoh-transport/tests/unicast_openclose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use zenoh_transport::{
use zenoh_util::net::get_ipv4_ipaddrs;

const TIMEOUT: Duration = Duration::from_secs(60);
const TIMEOUT_EXPECTED: Duration = Duration::from_secs(5);
const SLEEP: Duration = Duration::from_millis(100);

macro_rules! ztimeout {
Expand All @@ -36,6 +37,12 @@ macro_rules! ztimeout {
};
}

macro_rules! ztimeout_expected {
($f:expr) => {
$f.timeout(TIMEOUT_EXPECTED).await.unwrap()
};
}

#[cfg(test)]
#[derive(Default)]
struct SHRouterOpenClose;
Expand Down Expand Up @@ -83,7 +90,11 @@ impl TransportEventHandler for SHClientOpenClose {
}
}

async fn openclose_transport(endpoint: &EndPoint, lowlatency_transport: bool) {
async fn openclose_transport(
listen_endpoint: &EndPoint,
connect_endpoint: &EndPoint,
lowlatency_transport: bool,
) {
/* [ROUTER] */
let router_id = ZenohId::try_from([1]).unwrap();

Expand Down Expand Up @@ -143,7 +154,7 @@ async fn openclose_transport(endpoint: &EndPoint, lowlatency_transport: bool) {
/* [1] */
println!("\nTransport Open Close [1a1]");
// Add the locator on the router
let res = ztimeout!(router_manager.add_listener(endpoint.clone()));
let res = ztimeout!(router_manager.add_listener(listen_endpoint.clone()));
println!("Transport Open Close [1a1]: {res:?}");
assert!(res.is_ok());
println!("Transport Open Close [1a2]");
Expand All @@ -156,7 +167,8 @@ async fn openclose_transport(endpoint: &EndPoint, lowlatency_transport: bool) {
let mut links_num = 1;

println!("Transport Open Close [1c1]");
let open_res = ztimeout!(client01_manager.open_transport_unicast(endpoint.clone()));
let open_res =
ztimeout_expected!(client01_manager.open_transport_unicast(connect_endpoint.clone()));
println!("Transport Open Close [1c2]: {res:?}");
assert!(open_res.is_ok());
let c_ses1 = open_res.unwrap();
Expand Down Expand Up @@ -198,7 +210,7 @@ async fn openclose_transport(endpoint: &EndPoint, lowlatency_transport: bool) {
links_num = 2;

println!("\nTransport Open Close [2a1]");
let res = ztimeout!(client01_manager.open_transport_unicast(endpoint.clone()));
let res = ztimeout!(client01_manager.open_transport_unicast(connect_endpoint.clone()));
println!("Transport Open Close [2a2]: {res:?}");
assert!(res.is_ok());
let c_ses2 = res.unwrap();
Expand Down Expand Up @@ -238,7 +250,7 @@ async fn openclose_transport(endpoint: &EndPoint, lowlatency_transport: bool) {
// Open transport -> This should be rejected because
// of the maximum limit of links per transport
println!("\nTransport Open Close [3a1]");
let res = ztimeout!(client01_manager.open_transport_unicast(endpoint.clone()));
let res = ztimeout!(client01_manager.open_transport_unicast(connect_endpoint.clone()));
println!("Transport Open Close [3a2]: {res:?}");
assert!(res.is_err());
println!("Transport Open Close [3b1]");
Expand Down Expand Up @@ -297,7 +309,7 @@ async fn openclose_transport(endpoint: &EndPoint, lowlatency_transport: bool) {
links_num = 1;

println!("\nTransport Open Close [5a1]");
let res = ztimeout!(client01_manager.open_transport_unicast(endpoint.clone()));
let res = ztimeout!(client01_manager.open_transport_unicast(connect_endpoint.clone()));
println!("Transport Open Close [5a2]: {res:?}");
assert!(res.is_ok());
let c_ses3 = res.unwrap();
Expand Down Expand Up @@ -329,7 +341,7 @@ async fn openclose_transport(endpoint: &EndPoint, lowlatency_transport: bool) {
// Open transport -> This should be rejected because
// of the maximum limit of transports
println!("\nTransport Open Close [6a1]");
let res = ztimeout!(client02_manager.open_transport_unicast(endpoint.clone()));
let res = ztimeout!(client02_manager.open_transport_unicast(connect_endpoint.clone()));
println!("Transport Open Close [6a2]: {res:?}");
assert!(res.is_err());
println!("Transport Open Close [6b1]");
Expand Down Expand Up @@ -380,7 +392,7 @@ async fn openclose_transport(endpoint: &EndPoint, lowlatency_transport: bool) {
links_num = 1;

println!("\nTransport Open Close [8a1]");
let res = ztimeout!(client02_manager.open_transport_unicast(endpoint.clone()));
let res = ztimeout!(client02_manager.open_transport_unicast(connect_endpoint.clone()));
println!("Transport Open Close [8a2]: {res:?}");
assert!(res.is_ok());
let c_ses4 = res.unwrap();
Expand Down Expand Up @@ -438,7 +450,7 @@ async fn openclose_transport(endpoint: &EndPoint, lowlatency_transport: bool) {
/* [10] */
// Perform clean up of the open locators
println!("\nTransport Open Close [10a1]");
let res = ztimeout!(router_manager.del_listener(endpoint));
let res = ztimeout!(router_manager.del_listener(listen_endpoint));
println!("Transport Open Close [10a2]: {res:?}");
assert!(res.is_ok());

Expand All @@ -460,11 +472,11 @@ async fn openclose_transport(endpoint: &EndPoint, lowlatency_transport: bool) {
}

async fn openclose_universal_transport(endpoint: &EndPoint) {
openclose_transport(endpoint, false).await
openclose_transport(endpoint, endpoint, false).await
}

async fn openclose_lowlatency_transport(endpoint: &EndPoint) {
openclose_transport(endpoint, true).await
openclose_transport(endpoint, endpoint, true).await
}

#[cfg(feature = "transport_tcp")]
Expand Down Expand Up @@ -790,40 +802,106 @@ R+IdLiXcyIkg0m9N8I17p0ljCSkbrgGMD3bbePRTfg==
task::block_on(openclose_universal_transport(&endpoint));
}

#[cfg(feature = "transport_tcp")]
#[cfg(target_os = "linux")]
#[test]
#[should_panic(expected = "TimeoutError")]
fn openclose_tcp_only_connect_with_interface_restriction() {
let addrs = get_ipv4_ipaddrs(None);

let _ = env_logger::try_init();
task::block_on(async {
zasync_executor_init!();
});

let listen_endpoint: EndPoint = format!("tcp/{}:{}", addrs[0], 13001).parse().unwrap();

let connect_endpoint: EndPoint = format!("tcp/{}:{}#iface=lo", addrs[0], 13001)
.parse()
.unwrap();

// should not connect to local interface and external address
task::block_on(openclose_transport(
&listen_endpoint,
&connect_endpoint,
false,
));
}

#[cfg(feature = "transport_tcp")]
#[cfg(target_os = "linux")]
#[test]
#[should_panic(expected = "assertion failed: open_res.is_ok()")]
fn openclose_tcp_only_with_interface_restriction() {
fn openclose_tcp_only_listen_with_interface_restriction() {
let addrs = get_ipv4_ipaddrs(None);

let _ = env_logger::try_init();
task::block_on(async {
zasync_executor_init!();
});

let listen_endpoint: EndPoint = format!("tcp/{}:{}#iface=lo", addrs[0], 13002)
.parse()
.unwrap();

let connect_endpoint: EndPoint = format!("tcp/{}:{}", addrs[0], 13002).parse().unwrap();

// should not connect to local interface and external address
let endpoint: EndPoint = format!("tcp/{}:{}#iface=lo", addrs[0], 13001)
task::block_on(openclose_transport(
&listen_endpoint,
&connect_endpoint,
false,
));
}

#[cfg(feature = "transport_udp")]
#[cfg(target_os = "linux")]
#[test]
#[should_panic(expected = "TimeoutError")]
fn openclose_udp_only_connect_with_interface_restriction() {
let addrs = get_ipv4_ipaddrs(None);

let _ = env_logger::try_init();
task::block_on(async {
zasync_executor_init!();
});

let listen_endpoint: EndPoint = format!("udp/{}:{}", addrs[0], 13003).parse().unwrap();

let connect_endpoint: EndPoint = format!("udp/{}:{}#iface=lo", addrs[0], 13003)
.parse()
.unwrap();
task::block_on(openclose_universal_transport(&endpoint));

// should not connect to local interface and external address
task::block_on(openclose_transport(
&listen_endpoint,
&connect_endpoint,
false,
));
}

#[cfg(feature = "transport_udp")]
#[cfg(target_os = "linux")]
#[test]
#[should_panic(expected = "assertion failed: open_res.is_ok()")]
fn openclose_udp_only_with_interface_restriction() {
fn openclose_udp_onlyi_listen_with_interface_restriction() {
let addrs = get_ipv4_ipaddrs(None);

let _ = env_logger::try_init();
task::block_on(async {
zasync_executor_init!();
});

// should not connect to local interface and external address
let endpoint: EndPoint = format!("udp/{}:{}#iface=lo", addrs[0], 13011)
let listen_endpoint: EndPoint = format!("udp/{}:{}#iface=lo", addrs[0], 13004)
.parse()
.unwrap();
task::block_on(openclose_universal_transport(&endpoint));

let connect_endpoint: EndPoint = format!("udp/{}:{}", addrs[0], 13004).parse().unwrap();

// should not connect to local interface and external address
task::block_on(openclose_transport(
&listen_endpoint,
&connect_endpoint,
false,
));
}

0 comments on commit e96de43

Please sign in to comment.