use std::any::Any; use std::collections::{BTreeMap, BTreeSet}; use std::fmt; use std::net::{IpAddr, SocketAddr, SocketAddrV6}; use std::sync::{Arc, Weak}; use std::task::{Poll, ready}; use async_acme::acme::ACME_TLS_ALPN_NAME; use clap::Parser; use color_eyre::eyre::eyre; use futures::FutureExt; use futures::future::BoxFuture; use imbl::OrdMap; use imbl_value::{InOMap, InternedString}; use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn, from_fn_async}; use serde::{Deserialize, Serialize}; use tokio::net::{TcpListener, TcpStream}; use tokio_rustls::TlsConnector; use tokio_rustls::rustls::crypto::CryptoProvider; use tokio_rustls::rustls::pki_types::ServerName; use tokio_rustls::rustls::server::ClientHello; use tokio_rustls::rustls::{ClientConfig, ServerConfig}; use tracing::instrument; use ts_rs::TS; use visit_rs::Visit; use crate::context::{CliContext, RpcContext}; use crate::db::model::Database; use crate::db::model::public::{AcmeSettings, NetworkInterfaceInfo}; use crate::db::{DbAccessByKey, DbAccessMut}; use crate::net::acme::{ AcmeCertStore, AcmeProvider, AcmeTlsAlpnCache, AcmeTlsHandler, GetAcmeProvider, }; use crate::net::gateway::{ GatewayInfo, NetworkInterfaceController, NetworkInterfaceListenerAcceptMetadata, }; use crate::net::ssl::{CertStore, RootCaTlsHandler}; use crate::net::tls::{ ChainedHandler, TlsHandlerAction, TlsHandlerWrapper, TlsListener, TlsMetadata, WrapTlsHandler, }; use crate::net::utils::{ipv6_is_link_local, is_private_ip}; use crate::net::web_server::{Accept, AcceptStream, ExtractVisitor, TcpMetadata, extract}; use crate::prelude::*; use crate::util::collections::EqSet; use crate::util::future::{NonDetachingJoinHandle, WeakFuture}; use crate::util::serde::{HandlerExtSerde, MaybeUtf8String, display_serializable}; use crate::util::sync::{SyncMutex, Watch}; use crate::{GatewayId, ResultExt}; #[derive(Debug, Clone, Deserialize, Serialize, HasModel, TS)] #[serde(rename_all = "camelCase")] #[model = "Model"] #[ts(export)] pub struct PassthroughInfo { #[ts(type = "string")] pub hostname: InternedString, pub listen_port: u16, #[ts(type = "string")] pub backend: SocketAddr, #[ts(type = "string[]")] pub public_gateways: BTreeSet, #[ts(type = "string[]")] pub private_ips: BTreeSet, } #[derive(Debug, Clone, Deserialize, Serialize, Parser)] #[group(skip)] #[serde(rename_all = "kebab-case")] struct AddPassthroughParams { #[arg(long)] pub hostname: InternedString, #[arg(long)] pub listen_port: u16, #[arg(long)] pub backend: SocketAddr, #[arg(long)] pub public_gateway: Vec, #[arg(long)] pub private_ip: Vec, } #[derive(Debug, Clone, Deserialize, Serialize, Parser)] #[group(skip)] #[serde(rename_all = "kebab-case")] struct RemovePassthroughParams { #[arg(long)] pub hostname: InternedString, #[arg(long)] pub listen_port: u16, } pub fn vhost_api() -> ParentHandler { ParentHandler::new() .subcommand( "dump-table", from_fn(dump_table) .with_display_serializable() .with_custom_display_fn(|HandlerArgs { params, .. }, res| { use prettytable::*; if let Some(format) = params.format { display_serializable(format, res)?; return Ok::<_, Error>(()); } let mut table = Table::new(); table.add_row(row![bc => "FROM", "TO", "ACTIVE"]); for (external, targets) in res { for (host, targets) in targets { for (idx, target) in targets.into_iter().enumerate() { table.add_row(row![ format!( "{}:{}", host.as_ref().map(|s| &**s).unwrap_or("*"), external.0 ), target, idx == 0 ]); } } } table.print_tty(false)?; Ok(()) }) .with_about("about.dump-vhost-proxy-table") .with_call_remote::(), ) .subcommand( "add-passthrough", from_fn_async(add_passthrough) .no_display() .with_about("about.add-vhost-passthrough") .with_call_remote::(), ) .subcommand( "remove-passthrough", from_fn_async(remove_passthrough) .no_display() .with_about("about.remove-vhost-passthrough") .with_call_remote::(), ) .subcommand( "list-passthrough", from_fn(list_passthrough) .with_display_serializable() .with_about("about.list-vhost-passthrough") .with_call_remote::(), ) } fn dump_table( ctx: RpcContext, ) -> Result, BTreeMap>, EqSet>>, Error> { Ok(ctx.net_controller.vhost.dump_table()) } async fn add_passthrough( ctx: RpcContext, AddPassthroughParams { hostname, listen_port, backend, public_gateway, private_ip, }: AddPassthroughParams, ) -> Result<(), Error> { let public_gateways: BTreeSet = public_gateway.into_iter().collect(); let private_ips: BTreeSet = private_ip.into_iter().collect(); ctx.net_controller.vhost.add_passthrough( hostname.clone(), listen_port, backend, public_gateways.clone(), private_ips.clone(), )?; ctx.db .mutate(|db| { let pts = db .as_public_mut() .as_server_info_mut() .as_network_mut() .as_passthroughs_mut(); let mut vec: Vec = pts.de()?; vec.retain(|p| !(p.hostname == hostname && p.listen_port == listen_port)); vec.push(PassthroughInfo { hostname, listen_port, backend, public_gateways, private_ips, }); pts.ser(&vec) }) .await .result?; Ok(()) } async fn remove_passthrough( ctx: RpcContext, RemovePassthroughParams { hostname, listen_port, }: RemovePassthroughParams, ) -> Result<(), Error> { ctx.net_controller .vhost .remove_passthrough(&hostname, listen_port); ctx.db .mutate(|db| { let pts = db .as_public_mut() .as_server_info_mut() .as_network_mut() .as_passthroughs_mut(); let mut vec: Vec = pts.de()?; vec.retain(|p| !(p.hostname == hostname && p.listen_port == listen_port)); pts.ser(&vec) }) .await .result?; Ok(()) } fn list_passthrough(ctx: RpcContext) -> Result, Error> { Ok(ctx.net_controller.vhost.list_passthrough()) } // not allowed: <=1024, >=32768, 5355, 5432, 9050, 6010, 9051, 5353 struct PassthroughHandle { _rc: Arc<()>, backend: SocketAddr, public: BTreeSet, private: BTreeSet, } pub struct VHostController { db: TypedPatchDb, interfaces: Arc, crypto_provider: Arc, acme_cache: AcmeTlsAlpnCache, servers: SyncMutex>>, passthrough_handles: SyncMutex>, } impl VHostController { pub fn new( db: TypedPatchDb, interfaces: Arc, crypto_provider: Arc, passthroughs: Vec, ) -> Self { let controller = Self { db, interfaces, crypto_provider, acme_cache: Arc::new(SyncMutex::new(BTreeMap::new())), servers: SyncMutex::new(BTreeMap::new()), passthrough_handles: SyncMutex::new(BTreeMap::new()), }; for pt in passthroughs { if let Err(e) = controller.add_passthrough( pt.hostname, pt.listen_port, pt.backend, pt.public_gateways, pt.private_ips, ) { tracing::warn!("failed to restore passthrough: {e}"); } } controller } #[instrument(skip_all)] pub fn add( &self, hostname: Option, external: u16, target: DynVHostTarget, ) -> Result, Error> { self.servers.mutate(|writable| { let server = if let Some(server) = writable.remove(&external) { server } else { self.create_server(external) }; let rc = server.add(hostname, target); writable.insert(external, server); Ok(rc?) }) } fn create_server(&self, port: u16) -> VHostServer { let bind_reqs = Watch::new(VHostBindRequirements::default()); let listener = VHostBindListener { ip_info: self.interfaces.watcher.subscribe(), port, bind_reqs: bind_reqs.clone_unseen(), listeners: BTreeMap::new(), }; VHostServer::new( listener, bind_reqs, self.db.clone(), self.crypto_provider.clone(), self.acme_cache.clone(), ) } pub fn add_passthrough( &self, hostname: InternedString, port: u16, backend: SocketAddr, public: BTreeSet, private: BTreeSet, ) -> Result<(), Error> { let target = ProxyTarget { public: public.clone(), private: private.clone(), acme: None, addr: backend, add_x_forwarded_headers: false, connect_ssl: Err(AlpnInfo::Reflect), passthrough: true, }; let rc = self.add(Some(hostname.clone()), port, DynVHostTarget::new(target))?; self.passthrough_handles.mutate(|h| { h.insert( (hostname, port), PassthroughHandle { _rc: rc, backend, public, private, }, ); }); Ok(()) } pub fn remove_passthrough(&self, hostname: &InternedString, port: u16) { self.passthrough_handles .mutate(|h| h.remove(&(hostname.clone(), port))); self.gc(Some(hostname.clone()), port); } pub fn list_passthrough(&self) -> Vec { self.passthrough_handles.peek(|h| { h.iter() .map(|((hostname, port), handle)| PassthroughInfo { hostname: hostname.clone(), listen_port: *port, backend: handle.backend, public_gateways: handle.public.clone(), private_ips: handle.private.clone(), }) .collect() }) } pub fn dump_table( &self, ) -> BTreeMap, BTreeMap>, EqSet>> { self.servers.peek(|s| { s.iter() .map(|(k, v)| { ( JsonKey::new(*k), v.mapping.peek(|m| { m.iter() .map(|(k, v)| { ( JsonKey::new(k.clone()), v.iter() .filter(|(_, v)| v.strong_count() > 0) .map(|(k, _)| format!("{k:#?}")) .collect(), ) }) .collect() }), ) }) .collect() }) } #[instrument(skip_all)] pub fn gc(&self, hostname: Option, external: u16) { self.servers.mutate(|writable| { if let Some(server) = writable.remove(&external) { server.gc(hostname); if !server.is_empty() { writable.insert(external, server); } } }) } } /// Union of all ProxyTargets' bind requirements for a VHostServer. #[derive(Debug, Clone, Default, PartialEq, Eq)] pub struct VHostBindRequirements { pub public_gateways: BTreeSet, pub private_ips: BTreeSet, } fn compute_bind_reqs(mapping: &Mapping) -> VHostBindRequirements { let mut reqs = VHostBindRequirements::default(); for (_, targets) in mapping { for (target, rc) in targets { if rc.strong_count() > 0 { let (pub_gw, priv_ip) = target.0.bind_requirements(); reqs.public_gateways.extend(pub_gw); reqs.private_ips.extend(priv_ip); } } } reqs } /// Listener that manages its own TCP listeners with IP-level precision. /// Binds ALL IPs of public gateways and ONLY matching private IPs. pub struct VHostBindListener { ip_info: Watch>, port: u16, bind_reqs: Watch, listeners: BTreeMap, } fn update_vhost_listeners( listeners: &mut BTreeMap, port: u16, ip_info: &OrdMap, reqs: &VHostBindRequirements, ) -> Result<(), Error> { let mut keep = BTreeSet::::new(); for (gw_id, info) in ip_info { if let Some(ip_info) = &info.ip_info { for ipnet in &ip_info.subnets { let ip = ipnet.addr(); let should_bind = reqs.public_gateways.contains(gw_id) || reqs.private_ips.contains(&ip); if should_bind { let addr = match ip { IpAddr::V6(ip6) => SocketAddrV6::new( ip6, port, 0, if ipv6_is_link_local(ip6) { ip_info.scope_id } else { 0 }, ) .into(), ip => SocketAddr::new(ip, port), }; keep.insert(addr); if let Some((_, existing_info)) = listeners.get_mut(&addr) { *existing_info = GatewayInfo { id: gw_id.clone(), info: info.clone(), }; } else { let tcp = TcpListener::from_std( mio::net::TcpListener::bind(addr) .with_kind(ErrorKind::Network)? .into(), ) .with_kind(ErrorKind::Network)?; listeners.insert( addr, ( tcp, GatewayInfo { id: gw_id.clone(), info: info.clone(), }, ), ); } } } } } listeners.retain(|key, _| keep.contains(key)); Ok(()) } impl Accept for VHostBindListener { type Metadata = NetworkInterfaceListenerAcceptMetadata; fn poll_accept( &mut self, cx: &mut std::task::Context<'_>, ) -> Poll> { // Update listeners when ip_info or bind_reqs change while self.ip_info.poll_changed(cx).is_ready() || self.bind_reqs.poll_changed(cx).is_ready() { let reqs = self.bind_reqs.read_and_mark_seen(); let listeners = &mut self.listeners; let port = self.port; self.ip_info.peek_and_mark_seen(|ip_info| { update_vhost_listeners(listeners, port, ip_info, &reqs) })?; } // Poll each listener for incoming connections for (&addr, (listener, gw_info)) in &self.listeners { match listener.poll_accept(cx) { Poll::Ready(Ok((stream, peer_addr))) => { if let Err(e) = socket2::SockRef::from(&stream).set_keepalive(true) { tracing::error!("Failed to set tcp keepalive: {e}"); tracing::debug!("{e:?}"); } return Poll::Ready(Ok(( NetworkInterfaceListenerAcceptMetadata { inner: TcpMetadata { local_addr: addr, peer_addr, }, info: gw_info.clone(), }, Box::pin(stream), ))); } Poll::Ready(Err(e)) => { tracing::trace!("VHostBindListener accept error on {addr}: {e}"); } Poll::Pending => {} } } Poll::Pending } } pub trait VHostTarget: std::fmt::Debug + Eq { type PreprocessRes: Send + 'static; #[allow(unused_variables)] fn filter(&self, metadata: &::Metadata) -> bool { true } fn acme(&self) -> Option<&AcmeProvider> { None } /// Returns (public_gateways, private_ips) this target needs the listener to bind on. fn bind_requirements(&self) -> (BTreeSet, BTreeSet) { (BTreeSet::new(), BTreeSet::new()) } fn is_passthrough(&self) -> bool { false } fn preprocess<'a>( &'a self, prev: ServerConfig, hello: &'a ClientHello<'a>, metadata: &'a ::Metadata, ) -> impl Future> + Send + 'a; fn handle_stream( &self, stream: AcceptStream, metadata: TlsMetadata<::Metadata>, prev: Self::PreprocessRes, rc: Weak<()>, ); } pub trait DynVHostTargetT: std::fmt::Debug + Any { fn filter(&self, metadata: &::Metadata) -> bool; fn acme(&self) -> Option<&AcmeProvider>; fn bind_requirements(&self) -> (BTreeSet, BTreeSet); fn is_passthrough(&self) -> bool; fn preprocess<'a>( &'a self, prev: ServerConfig, hello: &'a ClientHello<'a>, metadata: &'a ::Metadata, ) -> BoxFuture<'a, Option<(ServerConfig, Box)>> where ::Metadata: Visit>; fn handle_stream( &self, stream: AcceptStream, metadata: TlsMetadata<::Metadata>, prev: Box, rc: Weak<()>, ); fn eq(&self, other: &dyn DynVHostTargetT) -> bool; } impl + 'static> DynVHostTargetT for T { fn filter(&self, metadata: &::Metadata) -> bool { VHostTarget::filter(self, metadata) } fn acme(&self) -> Option<&AcmeProvider> { VHostTarget::acme(self) } fn is_passthrough(&self) -> bool { VHostTarget::is_passthrough(self) } fn bind_requirements(&self) -> (BTreeSet, BTreeSet) { VHostTarget::bind_requirements(self) } fn preprocess<'a>( &'a self, prev: ServerConfig, hello: &'a ClientHello<'a>, metadata: &'a ::Metadata, ) -> BoxFuture<'a, Option<(ServerConfig, Box)>> { VHostTarget::preprocess(self, prev, hello, metadata) .map(|o| o.map(|(cfg, res)| (cfg, Box::new(res) as Box))) .boxed() } fn handle_stream( &self, stream: AcceptStream, metadata: TlsMetadata<::Metadata>, prev: Box, rc: Weak<()>, ) { if let Ok(prev) = prev.downcast() { VHostTarget::handle_stream(self, stream, metadata, *prev, rc); } } fn eq(&self, other: &dyn DynVHostTargetT) -> bool { Some(self) == (other as &dyn Any).downcast_ref() } } pub struct DynVHostTarget(Arc + Send + Sync>); impl DynVHostTarget { pub fn new + Send + Sync + 'static>(target: T) -> Self { Self(Arc::new(target)) } } impl Clone for DynVHostTarget { fn clone(&self) -> Self { Self(self.0.clone()) } } impl std::fmt::Debug for DynVHostTarget { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } } impl PartialEq for DynVHostTarget { fn eq(&self, other: &Self) -> bool { self.0.eq(&*other.0) } } impl Eq for DynVHostTarget {} struct Preprocessed(DynVHostTarget, Weak<()>, Box); impl fmt::Debug for Preprocessed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { (self.0).0.fmt(f) } } impl DynVHostTarget { async fn into_preprocessed( self, rc: Weak<()>, prev: ServerConfig, hello: &ClientHello<'_>, metadata: &::Metadata, ) -> Option<(ServerConfig, Preprocessed)> where ::Metadata: Visit>, { let (cfg, res) = self.0.preprocess(prev, hello, metadata).await?; Some((cfg, Preprocessed(self, rc, res))) } } impl Preprocessed { fn finish(self, stream: AcceptStream, metadata: TlsMetadata<::Metadata>) { (self.0).0.handle_stream(stream, metadata, self.2, self.1); } } #[derive(Clone)] pub struct ProxyTarget { pub public: BTreeSet, pub private: BTreeSet, pub acme: Option, pub addr: SocketAddr, pub add_x_forwarded_headers: bool, pub connect_ssl: Result, AlpnInfo>, // Ok: yes, connect using ssl, pass through alpn; Err: connect tcp, use provided strategy for alpn pub passthrough: bool, } impl PartialEq for ProxyTarget { fn eq(&self, other: &Self) -> bool { self.public == other.public && self.private == other.private && self.acme == other.acme && self.addr == other.addr && self.passthrough == other.passthrough && self.connect_ssl.as_ref().map(Arc::as_ptr) == other.connect_ssl.as_ref().map(Arc::as_ptr) } } impl Eq for ProxyTarget {} impl fmt::Debug for ProxyTarget { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ProxyTarget") .field("public", &self.public) .field("private", &self.private) .field("acme", &self.acme) .field("addr", &self.addr) .field("add_x_forwarded_headers", &self.add_x_forwarded_headers) .field("connect_ssl", &self.connect_ssl.as_ref().map(|_| ())) .field("passthrough", &self.passthrough) .finish() } } impl VHostTarget for ProxyTarget where A: Accept + 'static, ::Metadata: Visit> + Visit> + Clone + Send + Sync, { type PreprocessRes = AcceptStream; fn filter(&self, metadata: &::Metadata) -> bool { let gw = extract::(metadata); let tcp = extract::(metadata); let (Some(gw), Some(tcp)) = (gw, tcp) else { return false; }; let Some(ip_info) = &gw.info.ip_info else { return false; }; let src = tcp.peer_addr.ip(); let dst = tcp.local_addr.ip(); self.public.contains(&gw.id) || (self.private.contains(&dst) && (ip_info.subnets.iter().any(|s| s.contains(&src)) || is_private_ip(src))) } fn acme(&self) -> Option<&AcmeProvider> { self.acme.as_ref() } fn bind_requirements(&self) -> (BTreeSet, BTreeSet) { (self.public.clone(), self.private.clone()) } fn is_passthrough(&self) -> bool { self.passthrough } async fn preprocess<'a>( &'a self, mut prev: ServerConfig, hello: &'a ClientHello<'a>, _: &'a ::Metadata, ) -> Option<(ServerConfig, Self::PreprocessRes)> { let tcp_stream = TcpStream::connect(self.addr) .await .with_ctx(|_| (ErrorKind::Network, self.addr)) .log_err()?; if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) { tracing::error!("Failed to set tcp keepalive: {e}"); tracing::debug!("{e:?}"); } match &self.connect_ssl { Ok(client_cfg) => { let mut client_cfg = (&**client_cfg).clone(); client_cfg.alpn_protocols = hello .alpn() .into_iter() .flatten() .map(|x| x.to_vec()) .collect(); let target_stream = TlsConnector::from(Arc::new(client_cfg)) .connect_with( ServerName::IpAddress(self.addr.ip().into()), tcp_stream, |conn| { prev.alpn_protocols .extend(conn.alpn_protocol().into_iter().map(|p| p.to_vec())) }, ) .await .log_err()?; return Some((prev, Box::pin(target_stream))); } Err(AlpnInfo::Reflect) => { for alpn in hello.alpn().into_iter().flatten() { prev.alpn_protocols.push(alpn.to_vec()); } } Err(AlpnInfo::Specified(a)) => { for alpn in a { prev.alpn_protocols.push(alpn.0.clone()); } } } Some((prev, Box::pin(tcp_stream))) } fn handle_stream( &self, mut stream: AcceptStream, metadata: TlsMetadata<::Metadata>, mut prev: Self::PreprocessRes, rc: Weak<()>, ) { let add_x_forwarded_headers = self.add_x_forwarded_headers; tokio::spawn(async move { WeakFuture::new(rc, async move { if add_x_forwarded_headers { crate::net::http::run_http_proxy( stream, prev, metadata.tls_info.alpn, extract::(&metadata.inner).map(|m| m.peer_addr.ip()), ) .await .ok(); } else { tokio::io::copy_bidirectional(&mut stream, &mut prev) .await .ok(); } }) .await }); } } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, TS)] #[serde(rename_all = "camelCase")] #[ts(export)] pub enum AlpnInfo { Reflect, Specified(Vec), } impl Default for AlpnInfo { fn default() -> Self { Self::Reflect } } type Mapping = BTreeMap, InOMap, Weak<()>>>; pub struct GetVHostAcmeProvider(pub Watch>); impl Clone for GetVHostAcmeProvider { fn clone(&self) -> Self { Self(self.0.clone()) } } impl GetAcmeProvider for GetVHostAcmeProvider { async fn get_provider<'a, 'b: 'a>( &'b self, san_info: &'a BTreeSet, ) -> Option + Send + 'b> { self.0.peek(|m| -> Option { san_info .iter() .fold(Some::>(None), |acc, x| { let acc = acc?; if x.parse::().is_ok() { return Some(acc); } let (t, _) = m .get(&Some(x.clone()))? .iter() .find(|(_, rc)| rc.strong_count() > 0)?; let acme = t.0.acme()?; Some(if let Some(acc) = acc { if acme == acc { // all must match Some(acme) } else { None } } else { Some(acme) }) }) .flatten() .cloned() }) } } pub struct VHostConnector(Watch>, Option>); impl Clone for VHostConnector { fn clone(&self) -> Self { Self(self.0.clone(), None) } } impl WrapTlsHandler for VHostConnector where A: Accept + 'static, ::Metadata: Visit> + Visit> + Send + Sync, { async fn wrap<'a>( &'a mut self, prev: ServerConfig, hello: &'a ClientHello<'a>, metadata: &'a ::Metadata, ) -> Option where Self: 'a, { if hello .alpn() .into_iter() .flatten() .any(|a| a == ACME_TLS_ALPN_NAME) { return Some(TlsHandlerAction::Tls(prev)); } let (target, rc) = self.0.peek(|m| { m.get(&hello.server_name().map(InternedString::from)) .or_else(|| m.get(&None)) .into_iter() .flatten() .filter(|(_, rc)| rc.strong_count() > 0) .find(|(t, _)| t.0.filter(metadata)) .map(|(t, rc)| (t.clone(), rc.clone())) })?; let is_pt = target.0.is_passthrough(); let (prev, store) = target.into_preprocessed(rc, prev, hello, metadata).await?; self.1 = Some(store); if is_pt { Some(TlsHandlerAction::Passthrough) } else { Some(TlsHandlerAction::Tls(prev)) } } } struct VHostListener( TlsListener< A, TlsHandlerWrapper< ChainedHandler>>, RootCaTlsHandler>, VHostConnector, >, >, ) where for<'a> M: HasModel> + DbAccessMut + DbAccessMut + DbAccessByKey = &'a AcmeProvider> + Send + Sync + 'static, A: Accept + 'static, ::Metadata: Visit> + Visit> + Clone + Send + Sync + 'static; struct VHostListenerMetadata { inner: TlsMetadata, preprocessed: Preprocessed, } impl fmt::Debug for VHostListenerMetadata { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("VHostListenerMetadata") .field("inner", &self.inner) .field("preprocessed", &self.preprocessed) .finish() } } impl Accept for VHostListener where for<'a> M: HasModel> + DbAccessMut + DbAccessMut + DbAccessByKey = &'a AcmeProvider> + Send + Sync + 'static, A: Accept + 'static, ::Metadata: Visit> + Visit> + Clone + Send + Sync + 'static, { type Metadata = VHostListenerMetadata; fn poll_accept( &mut self, cx: &mut std::task::Context<'_>, ) -> Poll> { let (metadata, stream) = ready!(self.0.poll_accept(cx)?); let preprocessed = self.0.tls_handler.wrapper.1.take(); Poll::Ready(Ok(( VHostListenerMetadata { inner: metadata, preprocessed: preprocessed.ok_or_else(|| { Error::new( eyre!("tlslistener yielded but preprocessed isn't set"), ErrorKind::Incoherent, ) })?, }, stream, ))) } } impl VHostListener where for<'a> M: HasModel> + DbAccessMut + DbAccessMut + DbAccessByKey = &'a AcmeProvider> + Send + Sync + 'static, A: Accept + 'static, ::Metadata: Visit> + Visit> + Clone + Send + Sync + 'static, { async fn handle_next(&mut self) -> Result<(), Error> { let (metadata, stream) = futures::future::poll_fn(|cx| self.poll_accept(cx)).await?; metadata.preprocessed.finish(stream, metadata.inner); Ok(()) } } struct VHostServer { mapping: Watch>, bind_reqs: Watch, _thread: NonDetachingJoinHandle<()>, } impl VHostServer { #[instrument(skip_all)] fn new( listener: A, bind_reqs: Watch, db: TypedPatchDb, crypto_provider: Arc, acme_cache: AcmeTlsAlpnCache, ) -> Self where for<'a> M: HasModel> + DbAccessMut + DbAccessMut + DbAccessByKey = &'a AcmeProvider> + Send + Sync + 'static, A: Accept + Send + 'static, ::Metadata: Visit> + Visit> + Clone + Send + Sync + 'static, { let mapping = Watch::new(BTreeMap::new()); Self { mapping: mapping.clone(), bind_reqs, _thread: tokio::spawn(async move { let mut listener = VHostListener(TlsListener::new( listener, TlsHandlerWrapper { inner: ChainedHandler( Arc::new(AcmeTlsHandler { db: db.clone(), acme_cache, crypto_provider: crypto_provider.clone(), get_provider: GetVHostAcmeProvider(mapping.clone()), in_progress: Watch::new(BTreeMap::new()), }), RootCaTlsHandler { db, crypto_provider, }, ), wrapper: VHostConnector(mapping, None), }, )); loop { if let Err(e) = listener.handle_next().await { tracing::trace!("VHostServer: failed to accept connection: {e}"); tracing::trace!("{e:?}"); } } }) .into(), } } fn add( &self, hostname: Option, target: DynVHostTarget, ) -> Result, Error> { let target = target.into(); let mut res = Ok(Arc::new(())); self.mapping.send_if_modified(|writable| { let mut changed = false; let mut targets = writable.remove(&hostname).unwrap_or_default(); let rc = if let Some(rc) = Weak::upgrade(&targets.remove(&target).unwrap_or_default()) { rc } else { changed = true; Arc::new(()) }; targets.retain(|_, rc| rc.strong_count() > 0); targets.insert(target, Arc::downgrade(&rc)); writable.insert(hostname, targets); res = Ok(rc); if changed { self.update_bind_reqs(writable); } changed }); if self.mapping.watcher_count() > 1 { res } else { Err(Error::new( eyre!("VHost Service Thread has exited"), crate::ErrorKind::Network, )) } } fn gc(&self, hostname: Option) { self.mapping.send_if_modified(|writable| { let mut targets = writable.remove(&hostname).unwrap_or_default(); let pre = targets.len(); targets = targets .into_iter() .filter(|(_, rc)| rc.strong_count() > 0) .collect(); let post = targets.len(); if !targets.is_empty() { writable.insert(hostname, targets); } if pre != post { self.update_bind_reqs(writable); } pre == post }); } fn update_bind_reqs(&self, mapping: &Mapping) { let new_reqs = compute_bind_reqs(mapping); self.bind_reqs.send_if_modified(|reqs| { if *reqs != new_reqs { *reqs = new_reqs; true } else { false } }); } fn is_empty(&self) -> bool { self.mapping.peek(|m| m.is_empty()) } }