This commit is contained in:
Aiden McClelland
2023-06-15 18:37:27 -06:00
committed by Aiden McClelland
parent 04bf5f58d9
commit ccbf71c5e7
2 changed files with 63 additions and 21 deletions

View File

@@ -1,5 +1,5 @@
use std::convert::Infallible; use std::convert::Infallible;
use std::net::{Ipv4Addr, Ipv6Addr}; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::path::Path; use std::path::Path;
use async_stream::try_stream; use async_stream::try_stream;
@@ -7,24 +7,29 @@ use color_eyre::eyre::eyre;
use futures::stream::BoxStream; use futures::stream::BoxStream;
use futures::{StreamExt, TryStreamExt}; use futures::{StreamExt, TryStreamExt};
use ipnet::{Ipv4Net, Ipv6Net}; use ipnet::{Ipv4Net, Ipv6Net};
use tokio::net::{TcpListener, TcpStream};
use tokio::process::Command; use tokio::process::Command;
use crate::util::Invoke; use crate::util::Invoke;
use crate::Error; use crate::Error;
fn parse_iface_ip(output: &str) -> Result<Option<&str>, Error> { fn parse_iface_ip(output: &str) -> Result<Vec<&str>, Error> {
let output = output.trim(); let output = output.trim();
if output.is_empty() { if output.is_empty() {
return Ok(None); return Ok(Vec::new());
} }
if let Some(ip) = output.split_ascii_whitespace().nth(3) { let mut res = Vec::new();
Ok(Some(ip)) for line in output.lines() {
if let Some(ip) = line.split_ascii_whitespace().nth(3) {
res.push(ip)
} else { } else {
Err(Error::new( return Err(Error::new(
eyre!("malformed output from `ip`"), eyre!("malformed output from `ip`"),
crate::ErrorKind::Network, crate::ErrorKind::Network,
)) ));
} }
}
Ok(res)
} }
pub async fn get_iface_ipv4_addr(iface: &str) -> Result<Option<(Ipv4Addr, Ipv4Net)>, Error> { pub async fn get_iface_ipv4_addr(iface: &str) -> Result<Option<(Ipv4Addr, Ipv4Net)>, Error> {
@@ -38,7 +43,9 @@ pub async fn get_iface_ipv4_addr(iface: &str) -> Result<Option<(Ipv4Addr, Ipv4Ne
.invoke(crate::ErrorKind::Network) .invoke(crate::ErrorKind::Network)
.await?, .await?,
)?)? )?)?
.into_iter()
.map(|s| Ok::<_, Error>((s.split("/").next().unwrap().parse()?, s.parse()?))) .map(|s| Ok::<_, Error>((s.split("/").next().unwrap().parse()?, s.parse()?)))
.next()
.transpose()?) .transpose()?)
} }
@@ -53,6 +60,8 @@ pub async fn get_iface_ipv6_addr(iface: &str) -> Result<Option<(Ipv6Addr, Ipv6Ne
.invoke(crate::ErrorKind::Network) .invoke(crate::ErrorKind::Network)
.await?, .await?,
)?)? )?)?
.into_iter()
.find(|ip| !ip.starts_with("fe80::"))
.map(|s| Ok::<_, Error>((s.split("/").next().unwrap().parse()?, s.parse()?))) .map(|s| Ok::<_, Error>((s.split("/").next().unwrap().parse()?, s.parse()?)))
.transpose()?) .transpose()?)
} }
@@ -121,3 +130,37 @@ impl<T> hyper::server::accept::Accept for SingleAccept<T> {
std::task::Poll::Ready(self.project().0.take().map(Ok)) std::task::Poll::Ready(self.project().0.take().map(Ok))
} }
} }
pub struct TcpListeners {
listeners: Vec<TcpListener>,
}
impl TcpListeners {
pub fn new(listeners: impl IntoIterator<Item = TcpListener>) -> Self {
Self {
listeners: listeners.into_iter().collect(),
}
}
pub async fn accept(&self) -> std::io::Result<(TcpStream, SocketAddr)> {
futures::future::select_all(self.listeners.iter().map(|l| Box::pin(l.accept())))
.await
.0
}
}
impl hyper::server::accept::Accept for TcpListeners {
type Conn = TcpStream;
type Error = std::io::Error;
fn poll_accept(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
for listener in self.listeners.iter() {
let poll = listener.poll_accept(cx);
if poll.is_ready() {
return poll.map(Some);
}
}
std::task::Poll::Pending
}
}

View File

@@ -18,7 +18,7 @@ use tokio_rustls::{LazyConfigAcceptor, TlsConnector};
use crate::net::keys::Key; use crate::net::keys::Key;
use crate::net::ssl::SslManager; use crate::net::ssl::SslManager;
use crate::net::utils::SingleAccept; use crate::net::utils::{SingleAccept, TcpListeners};
use crate::util::io::BackTrackingReader; use crate::util::io::BackTrackingReader;
use crate::Error; use crate::Error;
@@ -88,21 +88,20 @@ struct VHostServer {
impl VHostServer { impl VHostServer {
async fn new(port: u16, ssl: Arc<SslManager>) -> Result<Self, Error> { async fn new(port: u16, ssl: Arc<SslManager>) -> Result<Self, Error> {
// check if port allowed // check if port allowed
let listener = TcpListener::bind( let listeners = TcpListeners::new([
[ TcpListener::bind(SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), port))
SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 80),
SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 80),
]
.as_ref(),
)
.await .await
.with_kind(crate::ErrorKind::Network)?; .with_kind(crate::ErrorKind::Network)?,
TcpListener::bind(SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), port))
.await
.with_kind(crate::ErrorKind::Network)?,
]);
let mapping = Arc::new(RwLock::new(BTreeMap::new())); let mapping = Arc::new(RwLock::new(BTreeMap::new()));
Ok(Self { Ok(Self {
mapping: Arc::downgrade(&mapping), mapping: Arc::downgrade(&mapping),
_thread: tokio::spawn(async move { _thread: tokio::spawn(async move {
loop { loop {
match listener.accept().await { match listeners.accept().await {
Ok((stream, _)) => { Ok((stream, _)) => {
let mut stream = BackTrackingReader::new(stream); let mut stream = BackTrackingReader::new(stream);
stream.start_buffering(); stream.start_buffering();