Files
start-os/core/startos/src/net/vhost.rs
Aiden McClelland 83133ced6a consolidate crates
2025-12-17 21:15:24 -07:00

758 lines
25 KiB
Rust

use std::any::Any;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt;
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, Weak};
use std::task::{Poll, ready};
use async_acme::acme::ACME_TLS_ALPN_NAME;
use color_eyre::eyre::eyre;
use futures::FutureExt;
use futures::future::BoxFuture;
use crate::util::future::NonDetachingJoinHandle;
use imbl_value::{InOMap, InternedString};
use crate::ResultExt;
use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn};
use serde::{Deserialize, Serialize};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
use tokio_rustls::rustls::crypto::CryptoProvider;
use tokio_rustls::rustls::pki_types::ServerName;
use tokio_rustls::rustls::server::ClientHello;
use tokio_rustls::rustls::{ClientConfig, ServerConfig};
use tracing::instrument;
use ts_rs::TS;
use visit_rs::Visit;
use crate::context::{CliContext, RpcContext};
use crate::db::model::Database;
use crate::db::model::public::AcmeSettings;
use crate::db::{DbAccessByKey, DbAccessMut};
use crate::net::acme::{
AcmeCertStore, AcmeProvider, AcmeTlsAlpnCache, AcmeTlsHandler, GetAcmeProvider,
};
use crate::net::gateway::{
AnyFilter, BindTcp, DynInterfaceFilter, GatewayInfo, InterfaceFilter,
NetworkInterfaceController, NetworkInterfaceListener,
};
use crate::net::ssl::{CertStore, RootCaTlsHandler};
use crate::net::tls::{
ChainedHandler, TlsHandlerWrapper, TlsListener, TlsMetadata, WrapTlsHandler,
};
use crate::net::web_server::{Accept, AcceptStream, ExtractVisitor, TcpMetadata, extract};
use crate::prelude::*;
use crate::util::collections::EqSet;
use crate::util::future::WeakFuture;
use crate::util::serde::{HandlerExtSerde, MaybeUtf8String, display_serializable};
use crate::util::sync::{SyncMutex, Watch};
/// Builds the RPC handler tree for virtual-host introspection.
///
/// The only subcommand is `dump-table`, which returns
/// [`VHostController::dump_table`] from the running server. On the CLI the
/// result is either serialized in the requested `format`, or rendered as a
/// `FROM` / `TO` / `ACTIVE` table in which only the first target listed for a
/// hostname is flagged active.
pub fn vhost_api<C: Context>() -> ParentHandler<C> {
    ParentHandler::new().subcommand(
        "dump-table",
        from_fn(|ctx: RpcContext| Ok(ctx.net_controller.vhost.dump_table()))
            .with_display_serializable()
            .with_custom_display_fn(|HandlerArgs { params, .. }, res| {
                use prettytable::*;

                // An explicit output format bypasses the table rendering entirely.
                if let Some(format) = params.format {
                    display_serializable(format, res)?;
                    return Ok::<_, Error>(());
                }

                let mut table = Table::new();
                table.add_row(row![bc => "FROM", "TO", "ACTIVE"]);
                for (port, by_host) in res {
                    for (host, destinations) in by_host {
                        for (idx, destination) in destinations.into_iter().enumerate() {
                            // A missing hostname is the wildcard entry for the port.
                            let from = format!(
                                "{}:{}",
                                host.as_ref().map(|s| &**s).unwrap_or("*"),
                                port.0
                            );
                            table.add_row(row![from, destination, idx == 0]);
                        }
                    }
                }
                table.print_tty(false)?;
                Ok(())
            })
            .with_call_remote::<CliContext>(),
    )
}
// not allowed: <=1024, >=32768, 5355, 5432, 9050, 6010, 9051, 5353
// NOTE(review): presumably the external ports that must not be handed to `add`; nothing in this
// file enforces the list — confirm where it is checked.
/// Owns the virtual-host servers, keyed by the external port they listen on.
pub struct VHostController {
// Database handle cloned into every `VHostServer` that gets created.
db: TypedPatchDb<Database>,
// Source of the per-port TCP listeners (`interfaces.watcher.bind`).
interfaces: Arc<NetworkInterfaceController>,
// Crypto provider cloned into every `VHostServer` that gets created.
crypto_provider: Arc<CryptoProvider>,
// ACME TLS-ALPN cache shared by all servers created by this controller.
acme_cache: AcmeTlsAlpnCache,
// One server per external port; entries are dropped by `gc` once they hold no mappings.
servers: SyncMutex<BTreeMap<u16, VHostServer<NetworkInterfaceListener>>>,
}
impl VHostController {
    /// Creates a controller with no servers and an empty ACME TLS-ALPN cache.
    pub fn new(
        db: TypedPatchDb<Database>,
        interfaces: Arc<NetworkInterfaceController>,
        crypto_provider: Arc<CryptoProvider>,
    ) -> Self {
        let acme_cache = Arc::new(SyncMutex::new(BTreeMap::new()));
        let servers = SyncMutex::new(BTreeMap::new());
        Self {
            db,
            interfaces,
            crypto_provider,
            acme_cache,
            servers,
        }
    }

    /// Registers `target` for `hostname` on the `external` port.
    ///
    /// A server for the port is created (binding a new listener) when none
    /// exists yet. The returned `Arc<()>` is the registration's liveness
    /// token: the mapping only keeps a `Weak` to it, and targets whose token
    /// has no strong references left are treated as inactive.
    ///
    /// # Errors
    /// Fails if the listener cannot be bound or if the server rejects the
    /// registration. An already existing server stays registered either way.
    #[instrument(skip_all)]
    pub fn add(
        &self,
        hostname: Option<InternedString>,
        external: u16,
        target: DynVHostTarget<NetworkInterfaceListener>,
    ) -> Result<Arc<()>, Error> {
        self.servers.mutate(|servers| {
            let server = match servers.remove(&external) {
                Some(existing) => existing,
                None => VHostServer::new(
                    self.interfaces.watcher.bind(BindTcp, external)?,
                    self.db.clone(),
                    self.crypto_provider.clone(),
                    self.acme_cache.clone(),
                ),
            };
            let registration = server.add(hostname, target);
            // Put the server back before surfacing a registration error.
            servers.insert(external, server);
            registration
        })
    }

    /// Snapshots the routing table as `port -> hostname -> live targets`.
    ///
    /// Targets are rendered with their pretty `Debug` representation; targets
    /// whose liveness token has been dropped are left out.
    pub fn dump_table(
        &self,
    ) -> BTreeMap<JsonKey<u16>, BTreeMap<JsonKey<Option<InternedString>>, EqSet<String>>> {
        self.servers.peek(|servers| {
            servers
                .iter()
                .map(|(port, server)| {
                    let hosts: BTreeMap<JsonKey<Option<InternedString>>, EqSet<String>> =
                        server.mapping.peek(|mapping| {
                            mapping
                                .iter()
                                .map(|(host, targets)| {
                                    let live: EqSet<String> = targets
                                        .iter()
                                        .filter(|(_, rc)| rc.strong_count() > 0)
                                        .map(|(target, _)| format!("{target:#?}"))
                                        .collect();
                                    (JsonKey::new(host.clone()), live)
                                })
                                .collect()
                        });
                    (JsonKey::new(*port), hosts)
                })
                .collect()
        })
    }

    /// Prunes dead targets for `hostname` on the `external` port and drops the
    /// port's server entirely once it has no mappings left.
    #[instrument(skip_all)]
    pub fn gc(&self, hostname: Option<InternedString>, external: u16) {
        self.servers.mutate(|servers| {
            let Some(server) = servers.remove(&external) else {
                return;
            };
            server.gc(hostname);
            if !server.is_empty() {
                servers.insert(external, server);
            }
        })
    }
}
/// A destination that an accepted TLS connection can be routed to.
///
/// `A` is the underlying acceptor; its `Metadata` describes the incoming connection.
pub trait VHostTarget<A: Accept>: std::fmt::Debug + Eq {
/// Value produced by `preprocess` and handed back to `handle_stream`.
type PreprocessRes: Send + 'static;
/// Whether this target may serve a connection with the given metadata. Defaults to `true`.
#[allow(unused_variables)]
fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool {
true
}
/// ACME provider associated with this target, if any. Defaults to `None`.
fn acme(&self) -> Option<&AcmeProvider> {
None
}
/// Runs after a target has been selected for a client hello. Receives the server config built
/// so far and returns the (possibly adjusted) config together with state for `handle_stream`,
/// or `None` if preprocessing did not succeed.
fn preprocess<'a>(
&'a self,
prev: ServerConfig,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> impl Future<Output = Option<(ServerConfig, Self::PreprocessRes)>> + Send + 'a;
/// Takes over the accepted stream. `prev` is the value returned by `preprocess`; `rc` is the
/// weak liveness token stored alongside this target in the mapping.
fn handle_stream(
&self,
stream: AcceptStream,
metadata: TlsMetadata<<A as Accept>::Metadata>,
prev: Self::PreprocessRes,
rc: Weak<()>,
);
}
/// Object-safe counterpart of [`VHostTarget`]: the associated `PreprocessRes` is erased to
/// `Box<dyn Any + Send>` and the future is boxed, so targets of different concrete types can be
/// stored behind one `dyn` pointer.
pub trait DynVHostTargetT<A: Accept>: std::fmt::Debug + Any {
/// See [`VHostTarget::filter`].
fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool;
/// See [`VHostTarget::acme`].
fn acme(&self) -> Option<&AcmeProvider>;
/// See [`VHostTarget::preprocess`]; the result state is type-erased.
fn preprocess<'a>(
&'a self,
prev: ServerConfig,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> BoxFuture<'a, Option<(ServerConfig, Box<dyn Any + Send>)>>
where
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>;
/// See [`VHostTarget::handle_stream`]; `prev` must be the box returned by `preprocess`.
fn handle_stream(
&self,
stream: AcceptStream,
metadata: TlsMetadata<<A as Accept>::Metadata>,
prev: Box<dyn Any + Send>,
rc: Weak<()>,
);
/// Dynamic equality against another type-erased target.
fn eq(&self, other: &dyn DynVHostTargetT<A>) -> bool;
}
/// Blanket adapter: every `'static` [`VHostTarget`] is usable as a type-erased
/// [`DynVHostTargetT`].
impl<A: Accept, T: VHostTarget<A> + 'static> DynVHostTargetT<A> for T {
    fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool {
        <T as VHostTarget<A>>::filter(self, metadata)
    }

    fn acme(&self) -> Option<&AcmeProvider> {
        <T as VHostTarget<A>>::acme(self)
    }

    /// Delegates to the typed `preprocess` and boxes its state as `dyn Any`.
    fn preprocess<'a>(
        &'a self,
        prev: ServerConfig,
        hello: &'a ClientHello<'a>,
        metadata: &'a <A as Accept>::Metadata,
    ) -> BoxFuture<'a, Option<(ServerConfig, Box<dyn Any + Send>)>> {
        let typed = VHostTarget::preprocess(self, prev, hello, metadata);
        typed
            .map(|outcome| {
                let (cfg, state) = outcome?;
                Some((cfg, Box::new(state) as Box<dyn Any + Send>))
            })
            .boxed()
    }

    /// Recovers the typed state and delegates to the typed `handle_stream`.
    /// If `prev` is not a `T::PreprocessRes` the call does nothing and the
    /// stream is dropped.
    fn handle_stream(
        &self,
        stream: AcceptStream,
        metadata: TlsMetadata<<A as Accept>::Metadata>,
        prev: Box<dyn Any + Send>,
        rc: Weak<()>,
    ) {
        let Ok(state) = prev.downcast::<<T as VHostTarget<A>>::PreprocessRes>() else {
            return;
        };
        VHostTarget::handle_stream(self, stream, metadata, *state, rc);
    }

    /// Two erased targets are equal when `other` has the same concrete type
    /// and compares equal under `T`'s own `Eq`.
    fn eq(&self, other: &dyn DynVHostTargetT<A>) -> bool {
        (other as &dyn Any)
            .downcast_ref::<T>()
            .is_some_and(|same_type| self == same_type)
    }
}
pub struct DynVHostTarget<A: Accept>(Arc<dyn DynVHostTargetT<A> + Send + Sync>);
impl<A: Accept> DynVHostTarget<A> {
pub fn new<T: VHostTarget<A> + Send + Sync + 'static>(target: T) -> Self {
Self(Arc::new(target))
}
}
impl<A: Accept> Clone for DynVHostTarget<A> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<A: Accept> std::fmt::Debug for DynVHostTarget<A> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl<A: Accept + 'static> PartialEq for DynVHostTarget<A> {
fn eq(&self, other: &Self) -> bool {
self.0.eq(&*other.0)
}
}
impl<A: Accept + 'static> Eq for DynVHostTarget<A> {}
/// A target that has already been preprocessed for one connection: the target
/// itself, its weak liveness token, and the type-erased state returned by
/// `preprocess`.
struct Preprocessed<A: Accept>(DynVHostTarget<A>, Weak<()>, Box<dyn Any + Send>);

impl<A: Accept> fmt::Debug for Preprocessed<A> {
    /// Shows only the target; the token and the opaque state are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Preprocessed(target, _, _) = self;
        fmt::Debug::fmt(target, f)
    }
}
impl<A: Accept + 'static> DynVHostTarget<A> {
    /// Runs the target's `preprocess` step and, on success, bundles the target
    /// with its liveness token `rc` and the resulting state.
    ///
    /// Returns `None` when the target's `preprocess` returns `None`.
    async fn into_preprocessed(
        self,
        rc: Weak<()>,
        prev: ServerConfig,
        hello: &ClientHello<'_>,
        metadata: &<A as Accept>::Metadata,
    ) -> Option<(ServerConfig, Preprocessed<A>)>
    where
        <A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>,
    {
        let outcome = self.0.preprocess(prev, hello, metadata).await;
        let (cfg, state) = outcome?;
        let bundle = Preprocessed(self, rc, state);
        Some((cfg, bundle))
    }
}
impl<A: Accept + 'static> Preprocessed<A> {
fn finish(self, stream: AcceptStream, metadata: TlsMetadata<<A as Accept>::Metadata>) {
(self.0).0.handle_stream(stream, metadata, self.2, self.1);
}
}
/// A [`VHostTarget`] that forwards accepted connections to a socket address.
#[derive(Clone)]
pub struct ProxyTarget {
/// Decides, from the connection's `GatewayInfo`, whether this target may serve it.
pub filter: DynInterfaceFilter,
/// ACME provider reported through `VHostTarget::acme`, if any.
pub acme: Option<AcmeProvider>,
/// Address the upstream TCP connection is opened to.
pub addr: SocketAddr,
/// When set, traffic goes through `run_http_proxy`; otherwise bytes are copied in both
/// directions unchanged.
pub add_x_forwarded_headers: bool,
pub connect_ssl: Result<Arc<ClientConfig>, AlpnInfo>, // Ok: yes, connect using ssl, pass through alpn; Err: connect tcp, use provided strategy for alpn
}
impl PartialEq for ProxyTarget {
    /// Targets are equal when filter, ACME provider and address match and the
    /// upstream TLS setting is the same: `Ok` configs are compared by pointer
    /// identity, `Err` strategies by value.
    ///
    /// NOTE(review): `add_x_forwarded_headers` does not take part in the
    /// comparison — confirm this is intentional.
    fn eq(&self, other: &Self) -> bool {
        let own_ssl = self.connect_ssl.as_ref().map(Arc::as_ptr);
        let their_ssl = other.connect_ssl.as_ref().map(Arc::as_ptr);
        self.filter == other.filter
            && self.acme == other.acme
            && self.addr == other.addr
            && own_ssl == their_ssl
    }
}

impl Eq for ProxyTarget {}
impl fmt::Debug for ProxyTarget {
    /// Prints every field; a TLS client config in `connect_ssl` is shown as
    /// `Ok(())` rather than being printed itself.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            filter,
            acme,
            addr,
            add_x_forwarded_headers,
            connect_ssl,
        } = self;
        let ssl_summary = connect_ssl.as_ref().map(|_| ());
        let mut out = f.debug_struct("ProxyTarget");
        out.field("filter", filter);
        out.field("acme", acme);
        out.field("addr", addr);
        out.field("add_x_forwarded_headers", add_x_forwarded_headers);
        out.field("connect_ssl", &ssl_summary);
        out.finish()
    }
}
impl<A> VHostTarget<A> for ProxyTarget
where
A: Accept + 'static,
<A as Accept>::Metadata: Visit<ExtractVisitor<GatewayInfo>>
+ Visit<ExtractVisitor<TcpMetadata>>
+ Clone
+ Send
+ Sync,
{
type PreprocessRes = AcceptStream;
fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool {
let info = extract::<GatewayInfo, _>(metadata);
if info.is_none() {
tracing::warn!("No GatewayInfo on metadata");
}
info.as_ref()
.map_or(true, |i| self.filter.filter(&i.id, &i.info))
}
fn acme(&self) -> Option<&AcmeProvider> {
self.acme.as_ref()
}
async fn preprocess<'a>(
&'a self,
mut prev: ServerConfig,
hello: &'a ClientHello<'a>,
_: &'a <A as Accept>::Metadata,
) -> Option<(ServerConfig, Self::PreprocessRes)> {
let tcp_stream = TcpStream::connect(self.addr)
.await
.with_ctx(|_| (ErrorKind::Network, self.addr))
.log_err()?;
match &self.connect_ssl {
Ok(client_cfg) => {
let mut client_cfg = (&**client_cfg).clone();
client_cfg.alpn_protocols = hello
.alpn()
.into_iter()
.flatten()
.map(|x| x.to_vec())
.collect();
let target_stream = TlsConnector::from(Arc::new(client_cfg))
.connect_with(
ServerName::IpAddress(self.addr.ip().into()),
tcp_stream,
|conn| {
prev.alpn_protocols
.extend(conn.alpn_protocol().into_iter().map(|p| p.to_vec()))
},
)
.await
.log_err()?;
return Some((prev, Box::pin(target_stream)));
}
Err(AlpnInfo::Reflect) => {
for alpn in hello.alpn().into_iter().flatten() {
prev.alpn_protocols.push(alpn.to_vec());
}
}
Err(AlpnInfo::Specified(a)) => {
for alpn in a {
prev.alpn_protocols.push(alpn.0.clone());
}
}
}
Some((prev, Box::pin(tcp_stream)))
}
fn handle_stream(
&self,
mut stream: AcceptStream,
metadata: TlsMetadata<<A as Accept>::Metadata>,
mut prev: Self::PreprocessRes,
rc: Weak<()>,
) {
let add_x_forwarded_headers = self.add_x_forwarded_headers;
tokio::spawn(async move {
WeakFuture::new(rc, async move {
if add_x_forwarded_headers {
crate::net::http::run_http_proxy(
stream,
prev,
metadata.tls_info.alpn,
extract::<TcpMetadata, _>(&metadata.inner).map(|m| m.peer_addr.ip()),
)
.await
.ok();
} else {
tokio::io::copy_bidirectional(&mut stream, &mut prev)
.await
.ok();
}
})
.await
});
}
}
// How the client-facing TLS config chooses its ALPN protocols when the upstream
// is reached over plain TCP (see `ProxyTarget::connect_ssl`).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub enum AlpnInfo {
    // Advertise whatever the client offered.
    Reflect,
    // Advertise exactly this list.
    Specified(Vec<MaybeUtf8String>),
}

impl Default for AlpnInfo {
    // Reflecting the client's offer is the default strategy.
    fn default() -> Self {
        AlpnInfo::Reflect
    }
}
/// Routing table of one server: hostname (`None` when the client sent no SNI)
/// to an insertion-ordered map of targets and their weak liveness tokens.
type Mapping<A> = BTreeMap<Option<InternedString>, InOMap<DynVHostTarget<A>, Weak<()>>>;

/// Looks up ACME providers from a server's live routing table.
pub struct GetVHostAcmeProvider<A: Accept + 'static>(pub Watch<Mapping<A>>);

impl<A: Accept + 'static> Clone for GetVHostAcmeProvider<A> {
    /// Clones the handle to the shared routing table.
    fn clone(&self) -> Self {
        let GetVHostAcmeProvider(mapping) = self;
        GetVHostAcmeProvider(mapping.clone())
    }
}
impl<A: Accept + 'static> GetAcmeProvider for GetVHostAcmeProvider<A> {
    /// Resolves the single ACME provider responsible for a certificate
    /// covering all names in `san_info`.
    ///
    /// IP-address SANs are ignored. Every remaining SAN must have a live
    /// target in the routing table, that target must have an ACME provider,
    /// and all of those providers must be equal. Returns `None` if any of
    /// these conditions fails, or if no hostname SAN is present.
    async fn get_provider<'a, 'b: 'a>(
        &'b self,
        san_info: &'a BTreeSet<InternedString>,
    ) -> Option<impl AsRef<AcmeProvider> + Send + 'b> {
        self.0.peek(|m| -> Option<AcmeProvider> {
            let mut agreed: Option<&AcmeProvider> = None;
            for san in san_info {
                if san.parse::<IpAddr>().is_ok() {
                    continue;
                }
                // The first target that is still alive is the active one for this hostname.
                let (target, _) = m
                    .get(&Some(san.clone()))?
                    .iter()
                    .find(|(_, rc)| rc.strong_count() > 0)?;
                let acme = target.0.acme()?;
                match agreed {
                    // All must match: a disagreement is final and must not let
                    // a later SAN re-seed the accumulator.
                    Some(prev) if prev != acme => return None,
                    _ => agreed = Some(acme),
                }
            }
            agreed.cloned()
        })
    }
}
/// TLS-handler wrapper that picks a target for each client hello.
///
/// Field 0 is the shared routing table; field 1 holds the target
/// preprocessed by the most recent successful `wrap` call until it is taken.
pub struct VHostConnector<A: Accept + 'static>(Watch<Mapping<A>>, Option<Preprocessed<A>>);

impl<A: Accept + 'static> Clone for VHostConnector<A> {
    /// Clones the routing-table handle only; the clone starts without a
    /// pending preprocessed target.
    fn clone(&self) -> Self {
        let table = self.0.clone();
        Self(table, None)
    }
}
impl<A> WrapTlsHandler<A> for VHostConnector<A>
where
    A: Accept + 'static,
    <A as Accept>::Metadata:
        Visit<ExtractVisitor<GatewayInfo>> + Visit<ExtractVisitor<TcpMetadata>> + Send + Sync,
{
    /// Selects and preprocesses the target for this client hello.
    ///
    /// Hellos offering the ACME TLS-ALPN protocol pass through with the config
    /// untouched. Otherwise the first live target registered for the SNI name
    /// whose filter accepts the connection is preprocessed, and the result is
    /// stored in `self.1`. Returns `None` when no target matches or its
    /// preprocessing fails.
    async fn wrap<'a>(
        &'a mut self,
        prev: ServerConfig,
        hello: &'a ClientHello<'a>,
        metadata: &'a <A as Accept>::Metadata,
    ) -> Option<ServerConfig>
    where
        Self: 'a,
    {
        let is_acme_challenge = hello
            .alpn()
            .into_iter()
            .flatten()
            .any(|proto| proto == ACME_TLS_ALPN_NAME);
        if is_acme_challenge {
            return Some(prev);
        }

        let sni = hello.server_name().map(InternedString::from);
        let selected = self.0.peek(|mapping| {
            mapping
                .get(&sni)
                .into_iter()
                .flatten()
                .filter(|(_, rc)| rc.strong_count() > 0)
                .find(|(candidate, _)| candidate.0.filter(metadata))
                .map(|(candidate, rc)| (candidate.clone(), rc.clone()))
        });
        let (target, rc) = selected?;

        let (cfg, preprocessed) = target.into_preprocessed(rc, prev, hello, metadata).await?;
        self.1 = Some(preprocessed);
        Some(cfg)
    }
}
/// TLS listener for one vhost server. Certificates come from the ACME handler chained with the
/// root-CA handler, and a `VHostConnector` wraps that chain to select the target per connection.
struct VHostListener<M, A>(
TlsListener<
A,
TlsHandlerWrapper<
ChainedHandler<Arc<AcmeTlsHandler<M, GetVHostAcmeProvider<A>>>, RootCaTlsHandler<M>>,
VHostConnector<A>,
>,
>,
)
where
for<'a> M: HasModel<Model = Model<M>>
+ DbAccessMut<CertStore>
+ DbAccessMut<AcmeCertStore>
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ Send
+ Sync,
A: Accept + 'static,
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
+ Visit<ExtractVisitor<GatewayInfo>>
+ Clone
+ Send
+ Sync
+ 'static;
/// Metadata yielded by `VHostListener`: the TLS metadata of the accepted
/// connection plus the target that was preprocessed for it.
struct VHostListenerMetadata<A: Accept> {
    inner: TlsMetadata<A::Metadata>,
    preprocessed: Preprocessed<A>,
}

impl<A: Accept> fmt::Debug for VHostListenerMetadata<A> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            inner,
            preprocessed,
        } = self;
        let mut out = f.debug_struct("VHostListenerMetadata");
        out.field("inner", inner);
        out.field("preprocessed", preprocessed);
        out.finish()
    }
}
impl<M, A> Accept for VHostListener<M, A>
where
for<'a> M: HasModel<Model = Model<M>>
+ DbAccessMut<CertStore>
+ DbAccessMut<AcmeCertStore>
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ Send
+ Sync
+ 'static,
A: Accept + 'static,
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
+ Visit<ExtractVisitor<GatewayInfo>>
+ Clone
+ Send
+ Sync
+ 'static,
{
type Metadata = VHostListenerMetadata<A>;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
let (metadata, stream) = ready!(self.0.poll_accept(cx)?);
let preprocessed = self.0.tls_handler.wrapper.1.take();
Poll::Ready(Ok((
VHostListenerMetadata {
inner: metadata,
preprocessed: preprocessed.ok_or_else(|| {
Error::new(
eyre!("tlslistener yielded but preprocessed isn't set"),
ErrorKind::Incoherent,
)
})?,
},
stream,
)))
}
}
impl<M, A> VHostListener<M, A>
where
for<'a> M: HasModel<Model = Model<M>>
+ DbAccessMut<CertStore>
+ DbAccessMut<AcmeCertStore>
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ Send
+ Sync
+ 'static,
A: Accept + 'static,
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
+ Visit<ExtractVisitor<GatewayInfo>>
+ Clone
+ Send
+ Sync
+ 'static,
{
async fn handle_next(&mut self) -> Result<(), Error> {
let (metadata, stream) = futures::future::poll_fn(|cx| self.poll_accept(cx)).await?;
metadata.preprocessed.finish(stream, metadata.inner);
Ok(())
}
}
/// One listening vhost server: its routing table plus the background task that accepts
/// connections for it.
struct VHostServer<A: Accept + 'static> {
// Routing table, also held by the accept task.
mapping: Watch<Mapping<A>>,
// Handle of the spawned accept loop; kept only so the task is owned by this struct.
_thread: NonDetachingJoinHandle<()>,
}
/// Builds an `AnyFilter` from the interface filters of every live
/// `ProxyTarget` in a hostname table; targets whose liveness token has been
/// dropped are skipped.
impl<'a> From<&'a BTreeMap<Option<InternedString>, BTreeMap<ProxyTarget, Weak<()>>>> for AnyFilter {
    fn from(value: &'a BTreeMap<Option<InternedString>, BTreeMap<ProxyTarget, Weak<()>>>) -> Self {
        let live_filters = value.values().flat_map(|targets| {
            targets
                .iter()
                .filter_map(|(target, rc)| (rc.strong_count() > 0).then(|| target.filter.clone()))
        });
        Self(live_filters.collect())
    }
}
impl<A: Accept> VHostServer<A> {
    /// Starts a server on `listener`.
    ///
    /// Spawns a task that wraps the listener in a TLS listener (ACME handler
    /// chained with the root-CA handler, wrapped by a `VHostConnector`) and
    /// then accepts connections in a loop. Accept errors are logged at trace
    /// level and the loop continues.
    #[instrument(skip_all)]
    fn new<M: HasModel>(
        listener: A,
        db: TypedPatchDb<M>,
        crypto_provider: Arc<CryptoProvider>,
        acme_cache: AcmeTlsAlpnCache,
    ) -> Self
    where
        for<'a> M: HasModel<Model = Model<M>>
            + DbAccessMut<CertStore>
            + DbAccessMut<AcmeCertStore>
            + DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
            + Send
            + Sync
            + 'static,
        A: Accept + Send + 'static,
        <A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
            + Visit<ExtractVisitor<GatewayInfo>>
            + Clone
            + Send
            + Sync
            + 'static,
    {
        let mapping = Watch::new(BTreeMap::new());
        Self {
            mapping: mapping.clone(),
            _thread: tokio::spawn(async move {
                let mut listener = VHostListener(TlsListener::new(
                    listener,
                    TlsHandlerWrapper {
                        inner: ChainedHandler(
                            Arc::new(AcmeTlsHandler {
                                db: db.clone(),
                                acme_cache,
                                crypto_provider: crypto_provider.clone(),
                                get_provider: GetVHostAcmeProvider(mapping.clone()),
                                in_progress: Watch::new(BTreeSet::new()),
                            }),
                            RootCaTlsHandler {
                                db,
                                crypto_provider,
                            },
                        ),
                        wrapper: VHostConnector(mapping, None),
                    },
                ));
                loop {
                    if let Err(e) = listener.handle_next().await {
                        tracing::trace!("VHostServer: failed to accept connection: {e}");
                        tracing::trace!("{e:?}");
                    }
                }
            })
            .into(),
        }
    }

    /// Registers `target` for `hostname` and returns its liveness token.
    ///
    /// If an equal target is already registered and still alive, its existing
    /// token is returned and watchers are not notified. Otherwise a fresh
    /// token is created. Dead targets for the hostname are pruned, and the
    /// target is (re)inserted at the end of the hostname's target list.
    ///
    /// # Errors
    /// Returns a `Network` error when the mapping has no watcher other than
    /// this one, which is taken to mean the service task has exited.
    fn add(
        &self,
        hostname: Option<InternedString>,
        target: DynVHostTarget<A>,
    ) -> Result<Arc<()>, Error> {
        let mut res = Ok(Arc::new(()));
        self.mapping.send_if_modified(|writable| {
            let mut changed = false;
            let mut targets = writable.remove(&hostname).unwrap_or_default();
            let rc = if let Some(rc) = Weak::upgrade(&targets.remove(&target).unwrap_or_default()) {
                rc
            } else {
                changed = true;
                Arc::new(())
            };
            targets.retain(|_, rc| rc.strong_count() > 0);
            targets.insert(target, Arc::downgrade(&rc));
            writable.insert(hostname, targets);
            res = Ok(rc);
            changed
        });
        if self.mapping.watcher_count() > 1 {
            res
        } else {
            Err(Error::new(
                eyre!("VHost Service Thread has exited"),
                crate::ErrorKind::Network,
            ))
        }
    }

    /// Drops every target of `hostname` whose liveness token is gone, and
    /// removes the hostname entry once no target is left.
    ///
    /// Watchers are notified only when the mapping actually changed.
    fn gc(&self, hostname: Option<InternedString>) {
        self.mapping.send_if_modified(|writable| {
            let Some(mut targets) = writable.remove(&hostname) else {
                // Nothing registered for this hostname: nothing to prune.
                return false;
            };
            let pre = targets.len();
            targets.retain(|_, rc| rc.strong_count() > 0);
            // Either some targets were pruned, or the entry was already empty
            // and is being dropped altogether.
            let modified = targets.len() != pre || targets.is_empty();
            if !targets.is_empty() {
                writable.insert(hostname, targets);
            }
            modified
        });
    }

    /// Whether no hostname is registered on this server.
    fn is_empty(&self) -> bool {
        self.mapping.peek(|m| m.is_empty())
    }
}