use std::any::Any; use std::collections::{BTreeMap, BTreeSet}; use std::fmt; use std::net::{IpAddr, SocketAddr}; use std::sync::{Arc, Weak}; use std::task::{Poll, ready}; use std::time::Duration; use async_acme::acme::ACME_TLS_ALPN_NAME; use color_eyre::eyre::eyre; use futures::FutureExt; use futures::future::BoxFuture; use imbl_value::{InOMap, InternedString}; use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn}; use serde::{Deserialize, Serialize}; use tokio::net::TcpStream; use tokio_rustls::TlsConnector; use tokio_rustls::rustls::crypto::CryptoProvider; use tokio_rustls::rustls::pki_types::ServerName; use tokio_rustls::rustls::server::ClientHello; use tokio_rustls::rustls::{ClientConfig, ServerConfig}; use tracing::instrument; use ts_rs::TS; use visit_rs::Visit; use crate::ResultExt; use crate::context::{CliContext, RpcContext}; use crate::db::model::Database; use crate::db::model::public::AcmeSettings; use crate::db::{DbAccessByKey, DbAccessMut}; use crate::net::acme::{ AcmeCertStore, AcmeProvider, AcmeTlsAlpnCache, AcmeTlsHandler, GetAcmeProvider, }; use crate::net::gateway::{ AnyFilter, BindTcp, DynInterfaceFilter, GatewayInfo, InterfaceFilter, NetworkInterfaceController, NetworkInterfaceListener, }; use crate::net::ssl::{CertStore, RootCaTlsHandler}; use crate::net::tls::{ ChainedHandler, TlsHandlerWrapper, TlsListener, TlsMetadata, WrapTlsHandler, }; use crate::net::web_server::{Accept, AcceptStream, ExtractVisitor, TcpMetadata, extract}; use crate::prelude::*; use crate::util::collections::EqSet; use crate::util::future::{NonDetachingJoinHandle, WeakFuture}; use crate::util::serde::{HandlerExtSerde, MaybeUtf8String, display_serializable}; use crate::util::sync::{SyncMutex, Watch}; pub fn vhost_api() -> ParentHandler { ParentHandler::new().subcommand( "dump-table", from_fn(|ctx: RpcContext| Ok(ctx.net_controller.vhost.dump_table())) .with_display_serializable() .with_custom_display_fn(|HandlerArgs { params, .. }, res| { use prettytable::*; if let Some(format) = params.format { display_serializable(format, res)?; return Ok::<_, Error>(()); } let mut table = Table::new(); table.add_row(row![bc => "FROM", "TO", "ACTIVE"]); for (external, targets) in res { for (host, targets) in targets { for (idx, target) in targets.into_iter().enumerate() { table.add_row(row![ format!( "{}:{}", host.as_ref().map(|s| &**s).unwrap_or("*"), external.0 ), target, idx == 0 ]); } } } table.print_tty(false)?; Ok(()) }) .with_call_remote::(), ) } // not allowed: <=1024, >=32768, 5355, 5432, 9050, 6010, 9051, 5353 pub struct VHostController { db: TypedPatchDb, interfaces: Arc, crypto_provider: Arc, acme_cache: AcmeTlsAlpnCache, servers: SyncMutex>>, } impl VHostController { pub fn new( db: TypedPatchDb, interfaces: Arc, crypto_provider: Arc, ) -> Self { Self { db, interfaces, crypto_provider, acme_cache: Arc::new(SyncMutex::new(BTreeMap::new())), servers: SyncMutex::new(BTreeMap::new()), } } #[instrument(skip_all)] pub fn add( &self, hostname: Option, external: u16, target: DynVHostTarget, ) -> Result, Error> { self.servers.mutate(|writable| { let server = if let Some(server) = writable.remove(&external) { server } else { VHostServer::new( self.interfaces.watcher.bind(BindTcp, external)?, self.db.clone(), self.crypto_provider.clone(), self.acme_cache.clone(), ) }; let rc = server.add(hostname, target); writable.insert(external, server); Ok(rc?) }) } pub fn dump_table( &self, ) -> BTreeMap, BTreeMap>, EqSet>> { self.servers.peek(|s| { s.iter() .map(|(k, v)| { ( JsonKey::new(*k), v.mapping.peek(|m| { m.iter() .map(|(k, v)| { ( JsonKey::new(k.clone()), v.iter() .filter(|(_, v)| v.strong_count() > 0) .map(|(k, _)| format!("{k:#?}")) .collect(), ) }) .collect() }), ) }) .collect() }) } #[instrument(skip_all)] pub fn gc(&self, hostname: Option, external: u16) { self.servers.mutate(|writable| { if let Some(server) = writable.remove(&external) { server.gc(hostname); if !server.is_empty() { writable.insert(external, server); } } }) } } pub trait VHostTarget: std::fmt::Debug + Eq { type PreprocessRes: Send + 'static; #[allow(unused_variables)] fn filter(&self, metadata: &::Metadata) -> bool { true } fn acme(&self) -> Option<&AcmeProvider> { None } fn preprocess<'a>( &'a self, prev: ServerConfig, hello: &'a ClientHello<'a>, metadata: &'a ::Metadata, ) -> impl Future> + Send + 'a; fn handle_stream( &self, stream: AcceptStream, metadata: TlsMetadata<::Metadata>, prev: Self::PreprocessRes, rc: Weak<()>, ); } pub trait DynVHostTargetT: std::fmt::Debug + Any { fn filter(&self, metadata: &::Metadata) -> bool; fn acme(&self) -> Option<&AcmeProvider>; fn preprocess<'a>( &'a self, prev: ServerConfig, hello: &'a ClientHello<'a>, metadata: &'a ::Metadata, ) -> BoxFuture<'a, Option<(ServerConfig, Box)>> where ::Metadata: Visit>; fn handle_stream( &self, stream: AcceptStream, metadata: TlsMetadata<::Metadata>, prev: Box, rc: Weak<()>, ); fn eq(&self, other: &dyn DynVHostTargetT) -> bool; } impl + 'static> DynVHostTargetT for T { fn filter(&self, metadata: &::Metadata) -> bool { VHostTarget::filter(self, metadata) } fn acme(&self) -> Option<&AcmeProvider> { VHostTarget::acme(self) } fn preprocess<'a>( &'a self, prev: ServerConfig, hello: &'a ClientHello<'a>, metadata: &'a ::Metadata, ) -> BoxFuture<'a, Option<(ServerConfig, Box)>> { VHostTarget::preprocess(self, prev, hello, metadata) .map(|o| o.map(|(cfg, res)| (cfg, Box::new(res) as Box))) .boxed() } fn handle_stream( &self, stream: AcceptStream, metadata: TlsMetadata<::Metadata>, prev: Box, rc: Weak<()>, ) { if let Ok(prev) = prev.downcast() { VHostTarget::handle_stream(self, stream, metadata, *prev, rc); } } fn eq(&self, other: &dyn DynVHostTargetT) -> bool { Some(self) == (other as &dyn Any).downcast_ref() } } pub struct DynVHostTarget(Arc + Send + Sync>); impl DynVHostTarget { pub fn new + Send + Sync + 'static>(target: T) -> Self { Self(Arc::new(target)) } } impl Clone for DynVHostTarget { fn clone(&self) -> Self { Self(self.0.clone()) } } impl std::fmt::Debug for DynVHostTarget { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } } impl PartialEq for DynVHostTarget { fn eq(&self, other: &Self) -> bool { self.0.eq(&*other.0) } } impl Eq for DynVHostTarget {} struct Preprocessed(DynVHostTarget, Weak<()>, Box); impl fmt::Debug for Preprocessed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { (self.0).0.fmt(f) } } impl DynVHostTarget { async fn into_preprocessed( self, rc: Weak<()>, prev: ServerConfig, hello: &ClientHello<'_>, metadata: &::Metadata, ) -> Option<(ServerConfig, Preprocessed)> where ::Metadata: Visit>, { let (cfg, res) = self.0.preprocess(prev, hello, metadata).await?; Some((cfg, Preprocessed(self, rc, res))) } } impl Preprocessed { fn finish(self, stream: AcceptStream, metadata: TlsMetadata<::Metadata>) { (self.0).0.handle_stream(stream, metadata, self.2, self.1); } } #[derive(Clone)] pub struct ProxyTarget { pub filter: DynInterfaceFilter, pub acme: Option, pub addr: SocketAddr, pub add_x_forwarded_headers: bool, pub connect_ssl: Result, AlpnInfo>, // Ok: yes, connect using ssl, pass through alpn; Err: connect tcp, use provided strategy for alpn } impl PartialEq for ProxyTarget { fn eq(&self, other: &Self) -> bool { self.filter == other.filter && self.acme == other.acme && self.addr == other.addr && self.connect_ssl.as_ref().map(Arc::as_ptr) == other.connect_ssl.as_ref().map(Arc::as_ptr) } } impl Eq for ProxyTarget {} impl fmt::Debug for ProxyTarget { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ProxyTarget") .field("filter", &self.filter) .field("acme", &self.acme) .field("addr", &self.addr) .field("add_x_forwarded_headers", &self.add_x_forwarded_headers) .field("connect_ssl", &self.connect_ssl.as_ref().map(|_| ())) .finish() } } impl VHostTarget for ProxyTarget where A: Accept + 'static, ::Metadata: Visit> + Visit> + Clone + Send + Sync, { type PreprocessRes = AcceptStream; fn filter(&self, metadata: &::Metadata) -> bool { let info = extract::(metadata); if info.is_none() { tracing::warn!("No GatewayInfo on metadata"); } info.as_ref() .map_or(true, |i| self.filter.filter(&i.id, &i.info)) } fn acme(&self) -> Option<&AcmeProvider> { self.acme.as_ref() } async fn preprocess<'a>( &'a self, mut prev: ServerConfig, hello: &'a ClientHello<'a>, _: &'a ::Metadata, ) -> Option<(ServerConfig, Self::PreprocessRes)> { let tcp_stream = TcpStream::connect(self.addr) .await .with_ctx(|_| (ErrorKind::Network, self.addr)) .log_err()?; if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) { tracing::error!("Failed to set tcp keepalive: {e}"); tracing::debug!("{e:?}"); } match &self.connect_ssl { Ok(client_cfg) => { let mut client_cfg = (&**client_cfg).clone(); client_cfg.alpn_protocols = hello .alpn() .into_iter() .flatten() .map(|x| x.to_vec()) .collect(); let target_stream = TlsConnector::from(Arc::new(client_cfg)) .connect_with( ServerName::IpAddress(self.addr.ip().into()), tcp_stream, |conn| { prev.alpn_protocols .extend(conn.alpn_protocol().into_iter().map(|p| p.to_vec())) }, ) .await .log_err()?; return Some((prev, Box::pin(target_stream))); } Err(AlpnInfo::Reflect) => { for alpn in hello.alpn().into_iter().flatten() { prev.alpn_protocols.push(alpn.to_vec()); } } Err(AlpnInfo::Specified(a)) => { for alpn in a { prev.alpn_protocols.push(alpn.0.clone()); } } } Some((prev, Box::pin(tcp_stream))) } fn handle_stream( &self, mut stream: AcceptStream, metadata: TlsMetadata<::Metadata>, mut prev: Self::PreprocessRes, rc: Weak<()>, ) { let add_x_forwarded_headers = self.add_x_forwarded_headers; tokio::spawn(async move { WeakFuture::new(rc, async move { if add_x_forwarded_headers { crate::net::http::run_http_proxy( stream, prev, metadata.tls_info.alpn, extract::(&metadata.inner).map(|m| m.peer_addr.ip()), ) .await .ok(); } else { tokio::io::copy_bidirectional(&mut stream, &mut prev) .await .ok(); } }) .await }); } } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, TS)] #[serde(rename_all = "camelCase")] #[ts(export)] pub enum AlpnInfo { Reflect, Specified(Vec), } impl Default for AlpnInfo { fn default() -> Self { Self::Reflect } } type Mapping = BTreeMap, InOMap, Weak<()>>>; pub struct GetVHostAcmeProvider(pub Watch>); impl Clone for GetVHostAcmeProvider { fn clone(&self) -> Self { Self(self.0.clone()) } } impl GetAcmeProvider for GetVHostAcmeProvider { async fn get_provider<'a, 'b: 'a>( &'b self, san_info: &'a BTreeSet, ) -> Option + Send + 'b> { self.0.peek(|m| -> Option { san_info .iter() .fold(Some::>(None), |acc, x| { let acc = acc?; if x.parse::().is_ok() { return Some(acc); } let (t, _) = m .get(&Some(x.clone()))? .iter() .find(|(_, rc)| rc.strong_count() > 0)?; let acme = t.0.acme()?; Some(if let Some(acc) = acc { if acme == acc { // all must match Some(acme) } else { None } } else { Some(acme) }) }) .flatten() .cloned() }) } } pub struct VHostConnector(Watch>, Option>); impl Clone for VHostConnector { fn clone(&self) -> Self { Self(self.0.clone(), None) } } impl WrapTlsHandler for VHostConnector where A: Accept + 'static, ::Metadata: Visit> + Visit> + Send + Sync, { async fn wrap<'a>( &'a mut self, prev: ServerConfig, hello: &'a ClientHello<'a>, metadata: &'a ::Metadata, ) -> Option where Self: 'a, { if hello .alpn() .into_iter() .flatten() .any(|a| a == ACME_TLS_ALPN_NAME) { return Some(prev); } let (target, rc) = self.0.peek(|m| { m.get(&hello.server_name().map(InternedString::from)) .into_iter() .flatten() .filter(|(_, rc)| rc.strong_count() > 0) .find(|(t, _)| t.0.filter(metadata)) .map(|(t, rc)| (t.clone(), rc.clone())) })?; let (prev, store) = target.into_preprocessed(rc, prev, hello, metadata).await?; self.1 = Some(store); Some(prev) } } struct VHostListener( TlsListener< A, TlsHandlerWrapper< ChainedHandler>>, RootCaTlsHandler>, VHostConnector, >, >, ) where for<'a> M: HasModel> + DbAccessMut + DbAccessMut + DbAccessByKey = &'a AcmeProvider> + Send + Sync, A: Accept + 'static, ::Metadata: Visit> + Visit> + Clone + Send + Sync + 'static; struct VHostListenerMetadata { inner: TlsMetadata, preprocessed: Preprocessed, } impl fmt::Debug for VHostListenerMetadata { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("VHostListenerMetadata") .field("inner", &self.inner) .field("preprocessed", &self.preprocessed) .finish() } } impl Accept for VHostListener where for<'a> M: HasModel> + DbAccessMut + DbAccessMut + DbAccessByKey = &'a AcmeProvider> + Send + Sync + 'static, A: Accept + 'static, ::Metadata: Visit> + Visit> + Clone + Send + Sync + 'static, { type Metadata = VHostListenerMetadata; fn poll_accept( &mut self, cx: &mut std::task::Context<'_>, ) -> Poll> { let (metadata, stream) = ready!(self.0.poll_accept(cx)?); let preprocessed = self.0.tls_handler.wrapper.1.take(); Poll::Ready(Ok(( VHostListenerMetadata { inner: metadata, preprocessed: preprocessed.ok_or_else(|| { Error::new( eyre!("tlslistener yielded but preprocessed isn't set"), ErrorKind::Incoherent, ) })?, }, stream, ))) } } impl VHostListener where for<'a> M: HasModel> + DbAccessMut + DbAccessMut + DbAccessByKey = &'a AcmeProvider> + Send + Sync + 'static, A: Accept + 'static, ::Metadata: Visit> + Visit> + Clone + Send + Sync + 'static, { async fn handle_next(&mut self) -> Result<(), Error> { let (metadata, stream) = futures::future::poll_fn(|cx| self.poll_accept(cx)).await?; metadata.preprocessed.finish(stream, metadata.inner); Ok(()) } } struct VHostServer { mapping: Watch>, _thread: NonDetachingJoinHandle<()>, } impl<'a> From<&'a BTreeMap, BTreeMap>>> for AnyFilter { fn from(value: &'a BTreeMap, BTreeMap>>) -> Self { Self( value .iter() .flat_map(|(_, v)| { v.iter() .filter(|(_, r)| r.strong_count() > 0) .map(|(t, _)| t.filter.clone()) }) .collect(), ) } } impl VHostServer { #[instrument(skip_all)] fn new( listener: A, db: TypedPatchDb, crypto_provider: Arc, acme_cache: AcmeTlsAlpnCache, ) -> Self where for<'a> M: HasModel> + DbAccessMut + DbAccessMut + DbAccessByKey = &'a AcmeProvider> + Send + Sync + 'static, A: Accept + Send + 'static, ::Metadata: Visit> + Visit> + Clone + Send + Sync + 'static, { let mapping = Watch::new(BTreeMap::new()); Self { mapping: mapping.clone(), _thread: tokio::spawn(async move { let mut listener = VHostListener(TlsListener::new( listener, TlsHandlerWrapper { inner: ChainedHandler( Arc::new(AcmeTlsHandler { db: db.clone(), acme_cache, crypto_provider: crypto_provider.clone(), get_provider: GetVHostAcmeProvider(mapping.clone()), in_progress: Watch::new(BTreeSet::new()), }), RootCaTlsHandler { db, crypto_provider, }, ), wrapper: VHostConnector(mapping, None), }, )); loop { if let Err(e) = listener.handle_next().await { tracing::trace!("VHostServer: failed to accept connection: {e}"); tracing::trace!("{e:?}"); } } }) .into(), } } fn add( &self, hostname: Option, target: DynVHostTarget, ) -> Result, Error> { let target = target.into(); let mut res = Ok(Arc::new(())); self.mapping.send_if_modified(|writable| { let mut changed = false; let mut targets = writable.remove(&hostname).unwrap_or_default(); let rc = if let Some(rc) = Weak::upgrade(&targets.remove(&target).unwrap_or_default()) { rc } else { changed = true; Arc::new(()) }; targets.retain(|_, rc| rc.strong_count() > 0); targets.insert(target, Arc::downgrade(&rc)); writable.insert(hostname, targets); res = Ok(rc); changed }); if self.mapping.watcher_count() > 1 { res } else { Err(Error::new( eyre!("VHost Service Thread has exited"), crate::ErrorKind::Network, )) } } fn gc(&self, hostname: Option) { self.mapping.send_if_modified(|writable| { let mut targets = writable.remove(&hostname).unwrap_or_default(); let pre = targets.len(); targets = targets .into_iter() .filter(|(_, rc)| rc.strong_count() > 0) .collect(); let post = targets.len(); if !targets.is_empty() { writable.insert(hostname, targets); } pre == post }); } fn is_empty(&self) -> bool { self.mapping.peek(|m| m.is_empty()) } }