fix: extract hairpin check into platform-conditional function

The hairpin NAT check uses Linux-specific APIs (bind_device, raw fd
conversion). Extract it into a separate function with #[cfg(target_os)]
so the entire block is excluded on non-Linux platforms, rather than
guarding only the unsafe block.
This commit is contained in:
Aiden McClelland
2026-03-30 14:38:13 -06:00
parent 0d4dcf6c61
commit ce1da028ce

View File

@@ -290,21 +290,7 @@ pub async fn check_port(
));
};
let hairpinning = tokio::time::timeout(Duration::from_secs(5), async {
let dest = SocketAddr::new(ip.into(), port);
let socket = socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::STREAM, None)?;
socket.bind_device(Some(gateway.as_str().as_bytes()))?;
socket.bind(&SocketAddr::new(IpAddr::V4(local_ipv4), 0).into())?;
socket.set_nonblocking(true)?;
#[cfg(unix)]
let socket = unsafe {
use std::os::fd::{FromRawFd, IntoRawFd};
tokio::net::TcpSocket::from_raw_fd(socket.into_raw_fd())
};
socket.connect(dest).await.map(|_| ())
})
.await
.map_or(false, |r| r.is_ok());
let hairpinning = check_hairpin(gateway, local_ipv4, ip, port).await;
Ok(CheckPortRes {
ip,
@@ -315,6 +301,30 @@ pub async fn check_port(
})
}
#[cfg(target_os = "linux")]
async fn check_hairpin(gateway: GatewayId, local_ipv4: Ipv4Addr, ip: Ipv4Addr, port: u16) -> bool {
let hairpinning = tokio::time::timeout(Duration::from_secs(5), async {
let dest = SocketAddr::new(ip.into(), port);
let socket = socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::STREAM, None)?;
socket.bind_device(Some(gateway.as_str().as_bytes()))?;
socket.bind(&SocketAddr::new(IpAddr::V4(local_ipv4), 0).into())?;
socket.set_nonblocking(true)?;
let socket = unsafe {
use std::os::fd::{FromRawFd, IntoRawFd};
tokio::net::TcpSocket::from_raw_fd(socket.into_raw_fd())
};
socket.connect(dest).await.map(|_| ())
})
.await
.map_or(false, |r| r.is_ok());
hairpinning
}
#[cfg(not(target_os = "linux"))]
async fn check_hairpin(_: GatewayId, _: Ipv4Addr, _: Ipv4Addr, _: u16) -> bool {
false
}
#[derive(Debug, Clone, Deserialize, Serialize, Parser, TS)]
#[group(skip)]
#[serde(rename_all = "camelCase")]