diff --git a/backend/src/net/utils.rs b/backend/src/net/utils.rs index 1736912b2..f5d8f1498 100644 --- a/backend/src/net/utils.rs +++ b/backend/src/net/utils.rs @@ -1,5 +1,5 @@ use std::convert::Infallible; -use std::net::{Ipv4Addr, Ipv6Addr}; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; use std::path::Path; use async_stream::try_stream; @@ -7,24 +7,29 @@ use color_eyre::eyre::eyre; use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; use ipnet::{Ipv4Net, Ipv6Net}; +use tokio::net::{TcpListener, TcpStream}; use tokio::process::Command; use crate::util::Invoke; use crate::Error; -fn parse_iface_ip(output: &str) -> Result, Error> { +fn parse_iface_ip(output: &str) -> Result, Error> { let output = output.trim(); if output.is_empty() { - return Ok(None); + return Ok(Vec::new()); } - if let Some(ip) = output.split_ascii_whitespace().nth(3) { - Ok(Some(ip)) - } else { - Err(Error::new( - eyre!("malformed output from `ip`"), - crate::ErrorKind::Network, - )) + let mut res = Vec::new(); + for line in output.lines() { + if let Some(ip) = line.split_ascii_whitespace().nth(3) { + res.push(ip) + } else { + return Err(Error::new( + eyre!("malformed output from `ip`"), + crate::ErrorKind::Network, + )); + } } + Ok(res) } pub async fn get_iface_ipv4_addr(iface: &str) -> Result, Error> { @@ -38,7 +43,9 @@ pub async fn get_iface_ipv4_addr(iface: &str) -> Result((s.split("/").next().unwrap().parse()?, s.parse()?))) + .next() .transpose()?) } @@ -53,6 +60,8 @@ pub async fn get_iface_ipv6_addr(iface: &str) -> Result((s.split("/").next().unwrap().parse()?, s.parse()?))) .transpose()?) } @@ -121,3 +130,37 @@ impl hyper::server::accept::Accept for SingleAccept { std::task::Poll::Ready(self.project().0.take().map(Ok)) } } + +pub struct TcpListeners { + listeners: Vec, +} +impl TcpListeners { + pub fn new(listeners: impl IntoIterator) -> Self { + Self { + listeners: listeners.into_iter().collect(), + } + } + + pub async fn accept(&self) -> std::io::Result<(TcpStream, SocketAddr)> { + futures::future::select_all(self.listeners.iter().map(|l| Box::pin(l.accept()))) + .await + .0 + } +} +impl hyper::server::accept::Accept for TcpListeners { + type Conn = TcpStream; + type Error = std::io::Error; + + fn poll_accept( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll>> { + for listener in self.listeners.iter() { + let poll = listener.poll_accept(cx); + if poll.is_ready() { + return poll.map(Some); + } + } + std::task::Poll::Pending + } +} diff --git a/backend/src/net/vhost.rs b/backend/src/net/vhost.rs index 73c6aa37c..94952bd43 100644 --- a/backend/src/net/vhost.rs +++ b/backend/src/net/vhost.rs @@ -18,7 +18,7 @@ use tokio_rustls::{LazyConfigAcceptor, TlsConnector}; use crate::net::keys::Key; use crate::net::ssl::SslManager; -use crate::net::utils::SingleAccept; +use crate::net::utils::{SingleAccept, TcpListeners}; use crate::util::io::BackTrackingReader; use crate::Error; @@ -88,21 +88,20 @@ struct VHostServer { impl VHostServer { async fn new(port: u16, ssl: Arc) -> Result { // check if port allowed - let listener = TcpListener::bind( - [ - SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 80), - SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 80), - ] - .as_ref(), - ) - .await - .with_kind(crate::ErrorKind::Network)?; + let listeners = TcpListeners::new([ + TcpListener::bind(SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), port)) + .await + .with_kind(crate::ErrorKind::Network)?, + TcpListener::bind(SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), port)) + .await + .with_kind(crate::ErrorKind::Network)?, + ]); let mapping = Arc::new(RwLock::new(BTreeMap::new())); Ok(Self { mapping: Arc::downgrade(&mapping), _thread: tokio::spawn(async move { loop { - match listener.accept().await { + match listeners.accept().await { Ok((stream, _)) => { let mut stream = BackTrackingReader::new(stream); stream.start_buffering();