Files
start-os/core/src/net/vhost.rs
Matt Hill 8b89f016ad task fix and keyboard fix (#3130)
* task fix and keyboard fix

* fixes for build scripts

* passthrough feature

* feat: inline domain health checks and improve address UX

- addPublicDomain returns DNS query + port check results (AddPublicDomainRes)
  so frontend skips separate API calls after adding a domain
- addPrivateDomain returns check_dns result for the gateway
- Support multiple ports per domain in validation modal (deduplicated)
- Run port checks concurrently via futures::future::join_all
- Add note to add-domain dialog showing other interfaces on same host
- Add addXForwardedHeaders to knownProtocols in SDK Host.ts
- Add plugin filter kind, pluginId filter, matchesAny, and docs to
  getServiceInterface.ts
- Add PassthroughInfo type and passthroughs field to NetworkInfo
- Pluralize "port forwarding rules" in i18n dictionaries

* feat: add shared host note to private domain dialog with i18n

* fix: scope public domain to single binding and return single port check

Accept internalPort in AddPublicDomainParams to target a specific
binding. Disable the domain on all other bindings. Return a single
CheckPortRes instead of Vec. Revert multi-port UI to singular port
display from 0f8a66b35.

* better shared hostname approach,  and improve look-feel of addresses tables

* fix starttls

* preserve usb as top efi boot option

* fix race condition in wan ip check

* sdk beta.56

* various bug, improve smtp

* multiple bugs, better outbound gateway UX

* remove non option from smtp for better package compat

* bump sdk

---------

Co-authored-by: Aiden McClelland <me@drbonez.dev>
2026-03-06 00:30:06 -07:00

1177 lines
38 KiB
Rust

use std::any::Any;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt;
use std::net::{IpAddr, SocketAddr, SocketAddrV6};
use std::sync::{Arc, Weak};
use std::task::{Poll, ready};
use async_acme::acme::ACME_TLS_ALPN_NAME;
use clap::Parser;
use color_eyre::eyre::eyre;
use futures::FutureExt;
use futures::future::BoxFuture;
use imbl::OrdMap;
use imbl_value::{InOMap, InternedString};
use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn, from_fn_async};
use serde::{Deserialize, Serialize};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::TlsConnector;
use tokio_rustls::rustls::crypto::CryptoProvider;
use tokio_rustls::rustls::pki_types::ServerName;
use tokio_rustls::rustls::server::ClientHello;
use tokio_rustls::rustls::{ClientConfig, ServerConfig};
use tracing::instrument;
use ts_rs::TS;
use visit_rs::Visit;
use crate::context::{CliContext, RpcContext};
use crate::db::model::Database;
use crate::db::model::public::{AcmeSettings, NetworkInterfaceInfo};
use crate::db::{DbAccessByKey, DbAccessMut};
use crate::net::acme::{
AcmeCertStore, AcmeProvider, AcmeTlsAlpnCache, AcmeTlsHandler, GetAcmeProvider,
};
use crate::net::gateway::{
GatewayInfo, NetworkInterfaceController, NetworkInterfaceListenerAcceptMetadata,
};
use crate::net::ssl::{CertStore, RootCaTlsHandler};
use crate::net::tls::{
ChainedHandler, TlsHandlerAction, TlsHandlerWrapper, TlsListener, TlsMetadata, WrapTlsHandler,
};
use crate::net::utils::ipv6_is_link_local;
use crate::net::web_server::{Accept, AcceptStream, ExtractVisitor, TcpMetadata, extract};
use crate::prelude::*;
use crate::util::collections::EqSet;
use crate::util::future::{NonDetachingJoinHandle, WeakFuture};
use crate::util::serde::{HandlerExtSerde, MaybeUtf8String, display_serializable};
use crate::util::sync::{SyncMutex, Watch};
use crate::{GatewayId, ResultExt};
#[derive(Debug, Clone, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
#[ts(export)]
pub struct PassthroughInfo {
#[ts(type = "string")]
pub hostname: InternedString,
pub listen_port: u16,
#[ts(type = "string")]
pub backend: SocketAddr,
#[ts(type = "string[]")]
pub public_gateways: BTreeSet<GatewayId>,
#[ts(type = "string[]")]
pub private_ips: BTreeSet<IpAddr>,
}
#[derive(Debug, Clone, Deserialize, Serialize, Parser)]
#[serde(rename_all = "kebab-case")]
struct AddPassthroughParams {
#[arg(long)]
pub hostname: InternedString,
#[arg(long)]
pub listen_port: u16,
#[arg(long)]
pub backend: SocketAddr,
#[arg(long)]
pub public_gateway: Vec<GatewayId>,
#[arg(long)]
pub private_ip: Vec<IpAddr>,
}
#[derive(Debug, Clone, Deserialize, Serialize, Parser)]
#[serde(rename_all = "kebab-case")]
struct RemovePassthroughParams {
#[arg(long)]
pub hostname: InternedString,
#[arg(long)]
pub listen_port: u16,
}
pub fn vhost_api<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
.subcommand(
"dump-table",
from_fn(dump_table)
.with_display_serializable()
.with_custom_display_fn(|HandlerArgs { params, .. }, res| {
use prettytable::*;
if let Some(format) = params.format {
display_serializable(format, res)?;
return Ok::<_, Error>(());
}
let mut table = Table::new();
table.add_row(row![bc => "FROM", "TO", "ACTIVE"]);
for (external, targets) in res {
for (host, targets) in targets {
for (idx, target) in targets.into_iter().enumerate() {
table.add_row(row![
format!(
"{}:{}",
host.as_ref().map(|s| &**s).unwrap_or("*"),
external.0
),
target,
idx == 0
]);
}
}
}
table.print_tty(false)?;
Ok(())
})
.with_call_remote::<CliContext>(),
)
.subcommand(
"add-passthrough",
from_fn_async(add_passthrough)
.no_display()
.with_call_remote::<CliContext>(),
)
.subcommand(
"remove-passthrough",
from_fn_async(remove_passthrough)
.no_display()
.with_call_remote::<CliContext>(),
)
.subcommand(
"list-passthrough",
from_fn(list_passthrough)
.with_display_serializable()
.with_call_remote::<CliContext>(),
)
}
fn dump_table(
ctx: RpcContext,
) -> Result<BTreeMap<JsonKey<u16>, BTreeMap<JsonKey<Option<InternedString>>, EqSet<String>>>, Error>
{
Ok(ctx.net_controller.vhost.dump_table())
}
async fn add_passthrough(
ctx: RpcContext,
AddPassthroughParams {
hostname,
listen_port,
backend,
public_gateway,
private_ip,
}: AddPassthroughParams,
) -> Result<(), Error> {
let public_gateways: BTreeSet<GatewayId> = public_gateway.into_iter().collect();
let private_ips: BTreeSet<IpAddr> = private_ip.into_iter().collect();
ctx.net_controller.vhost.add_passthrough(
hostname.clone(),
listen_port,
backend,
public_gateways.clone(),
private_ips.clone(),
)?;
ctx.db
.mutate(|db| {
let pts = db
.as_public_mut()
.as_server_info_mut()
.as_network_mut()
.as_passthroughs_mut();
let mut vec: Vec<PassthroughInfo> = pts.de()?;
vec.retain(|p| !(p.hostname == hostname && p.listen_port == listen_port));
vec.push(PassthroughInfo {
hostname,
listen_port,
backend,
public_gateways,
private_ips,
});
pts.ser(&vec)
})
.await
.result?;
Ok(())
}
async fn remove_passthrough(
ctx: RpcContext,
RemovePassthroughParams {
hostname,
listen_port,
}: RemovePassthroughParams,
) -> Result<(), Error> {
ctx.net_controller
.vhost
.remove_passthrough(&hostname, listen_port);
ctx.db
.mutate(|db| {
let pts = db
.as_public_mut()
.as_server_info_mut()
.as_network_mut()
.as_passthroughs_mut();
let mut vec: Vec<PassthroughInfo> = pts.de()?;
vec.retain(|p| !(p.hostname == hostname && p.listen_port == listen_port));
pts.ser(&vec)
})
.await
.result?;
Ok(())
}
fn list_passthrough(ctx: RpcContext) -> Result<Vec<PassthroughInfo>, Error> {
Ok(ctx.net_controller.vhost.list_passthrough())
}
// not allowed: <=1024, >=32768, 5355, 5432, 9050, 6010, 9051, 5353
struct PassthroughHandle {
_rc: Arc<()>,
backend: SocketAddr,
public: BTreeSet<GatewayId>,
private: BTreeSet<IpAddr>,
}
pub struct VHostController {
db: TypedPatchDb<Database>,
interfaces: Arc<NetworkInterfaceController>,
crypto_provider: Arc<CryptoProvider>,
acme_cache: AcmeTlsAlpnCache,
servers: SyncMutex<BTreeMap<u16, VHostServer<VHostBindListener>>>,
passthrough_handles: SyncMutex<BTreeMap<(InternedString, u16), PassthroughHandle>>,
}
impl VHostController {
pub fn new(
db: TypedPatchDb<Database>,
interfaces: Arc<NetworkInterfaceController>,
crypto_provider: Arc<CryptoProvider>,
passthroughs: Vec<PassthroughInfo>,
) -> Self {
let controller = Self {
db,
interfaces,
crypto_provider,
acme_cache: Arc::new(SyncMutex::new(BTreeMap::new())),
servers: SyncMutex::new(BTreeMap::new()),
passthrough_handles: SyncMutex::new(BTreeMap::new()),
};
for pt in passthroughs {
if let Err(e) = controller.add_passthrough(
pt.hostname,
pt.listen_port,
pt.backend,
pt.public_gateways,
pt.private_ips,
) {
tracing::warn!("failed to restore passthrough: {e}");
}
}
controller
}
#[instrument(skip_all)]
pub fn add(
&self,
hostname: Option<InternedString>,
external: u16,
target: DynVHostTarget<VHostBindListener>,
) -> Result<Arc<()>, Error> {
self.servers.mutate(|writable| {
let server = if let Some(server) = writable.remove(&external) {
server
} else {
self.create_server(external)
};
let rc = server.add(hostname, target);
writable.insert(external, server);
Ok(rc?)
})
}
fn create_server(&self, port: u16) -> VHostServer<VHostBindListener> {
let bind_reqs = Watch::new(VHostBindRequirements::default());
let listener = VHostBindListener {
ip_info: self.interfaces.watcher.subscribe(),
port,
bind_reqs: bind_reqs.clone_unseen(),
listeners: BTreeMap::new(),
};
VHostServer::new(
listener,
bind_reqs,
self.db.clone(),
self.crypto_provider.clone(),
self.acme_cache.clone(),
)
}
pub fn add_passthrough(
&self,
hostname: InternedString,
port: u16,
backend: SocketAddr,
public: BTreeSet<GatewayId>,
private: BTreeSet<IpAddr>,
) -> Result<(), Error> {
let target = ProxyTarget {
public: public.clone(),
private: private.clone(),
acme: None,
addr: backend,
add_x_forwarded_headers: false,
connect_ssl: Err(AlpnInfo::Reflect),
passthrough: true,
};
let rc = self.add(Some(hostname.clone()), port, DynVHostTarget::new(target))?;
self.passthrough_handles.mutate(|h| {
h.insert(
(hostname, port),
PassthroughHandle {
_rc: rc,
backend,
public,
private,
},
);
});
Ok(())
}
pub fn remove_passthrough(&self, hostname: &InternedString, port: u16) {
self.passthrough_handles
.mutate(|h| h.remove(&(hostname.clone(), port)));
self.gc(Some(hostname.clone()), port);
}
pub fn list_passthrough(&self) -> Vec<PassthroughInfo> {
self.passthrough_handles.peek(|h| {
h.iter()
.map(|((hostname, port), handle)| PassthroughInfo {
hostname: hostname.clone(),
listen_port: *port,
backend: handle.backend,
public_gateways: handle.public.clone(),
private_ips: handle.private.clone(),
})
.collect()
})
}
pub fn dump_table(
&self,
) -> BTreeMap<JsonKey<u16>, BTreeMap<JsonKey<Option<InternedString>>, EqSet<String>>> {
self.servers.peek(|s| {
s.iter()
.map(|(k, v)| {
(
JsonKey::new(*k),
v.mapping.peek(|m| {
m.iter()
.map(|(k, v)| {
(
JsonKey::new(k.clone()),
v.iter()
.filter(|(_, v)| v.strong_count() > 0)
.map(|(k, _)| format!("{k:#?}"))
.collect(),
)
})
.collect()
}),
)
})
.collect()
})
}
#[instrument(skip_all)]
pub fn gc(&self, hostname: Option<InternedString>, external: u16) {
self.servers.mutate(|writable| {
if let Some(server) = writable.remove(&external) {
server.gc(hostname);
if !server.is_empty() {
writable.insert(external, server);
}
}
})
}
}
/// Union of all ProxyTargets' bind requirements for a VHostServer.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct VHostBindRequirements {
pub public_gateways: BTreeSet<GatewayId>,
pub private_ips: BTreeSet<IpAddr>,
}
fn compute_bind_reqs<A: Accept + 'static>(mapping: &Mapping<A>) -> VHostBindRequirements {
let mut reqs = VHostBindRequirements::default();
for (_, targets) in mapping {
for (target, rc) in targets {
if rc.strong_count() > 0 {
let (pub_gw, priv_ip) = target.0.bind_requirements();
reqs.public_gateways.extend(pub_gw);
reqs.private_ips.extend(priv_ip);
}
}
}
reqs
}
/// Listener that manages its own TCP listeners with IP-level precision.
/// Binds ALL IPs of public gateways and ONLY matching private IPs.
pub struct VHostBindListener {
ip_info: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
port: u16,
bind_reqs: Watch<VHostBindRequirements>,
listeners: BTreeMap<SocketAddr, (TcpListener, GatewayInfo)>,
}
fn update_vhost_listeners(
listeners: &mut BTreeMap<SocketAddr, (TcpListener, GatewayInfo)>,
port: u16,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
reqs: &VHostBindRequirements,
) -> Result<(), Error> {
let mut keep = BTreeSet::<SocketAddr>::new();
for (gw_id, info) in ip_info {
if let Some(ip_info) = &info.ip_info {
for ipnet in &ip_info.subnets {
let ip = ipnet.addr();
let should_bind =
reqs.public_gateways.contains(gw_id) || reqs.private_ips.contains(&ip);
if should_bind {
let addr = match ip {
IpAddr::V6(ip6) => SocketAddrV6::new(
ip6,
port,
0,
if ipv6_is_link_local(ip6) {
ip_info.scope_id
} else {
0
},
)
.into(),
ip => SocketAddr::new(ip, port),
};
keep.insert(addr);
if let Some((_, existing_info)) = listeners.get_mut(&addr) {
*existing_info = GatewayInfo {
id: gw_id.clone(),
info: info.clone(),
};
} else {
let tcp = TcpListener::from_std(
mio::net::TcpListener::bind(addr)
.with_kind(ErrorKind::Network)?
.into(),
)
.with_kind(ErrorKind::Network)?;
listeners.insert(
addr,
(
tcp,
GatewayInfo {
id: gw_id.clone(),
info: info.clone(),
},
),
);
}
}
}
}
}
listeners.retain(|key, _| keep.contains(key));
Ok(())
}
impl Accept for VHostBindListener {
type Metadata = NetworkInterfaceListenerAcceptMetadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
// Update listeners when ip_info or bind_reqs change
while self.ip_info.poll_changed(cx).is_ready() || self.bind_reqs.poll_changed(cx).is_ready()
{
let reqs = self.bind_reqs.read_and_mark_seen();
let listeners = &mut self.listeners;
let port = self.port;
self.ip_info.peek_and_mark_seen(|ip_info| {
update_vhost_listeners(listeners, port, ip_info, &reqs)
})?;
}
// Poll each listener for incoming connections
for (&addr, (listener, gw_info)) in &self.listeners {
match listener.poll_accept(cx) {
Poll::Ready(Ok((stream, peer_addr))) => {
if let Err(e) = socket2::SockRef::from(&stream).set_keepalive(true) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
return Poll::Ready(Ok((
NetworkInterfaceListenerAcceptMetadata {
inner: TcpMetadata {
local_addr: addr,
peer_addr,
},
info: gw_info.clone(),
},
Box::pin(stream),
)));
}
Poll::Ready(Err(e)) => {
tracing::trace!("VHostBindListener accept error on {addr}: {e}");
}
Poll::Pending => {}
}
}
Poll::Pending
}
}
pub trait VHostTarget<A: Accept>: std::fmt::Debug + Eq {
type PreprocessRes: Send + 'static;
#[allow(unused_variables)]
fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool {
true
}
fn acme(&self) -> Option<&AcmeProvider> {
None
}
/// Returns (public_gateways, private_ips) this target needs the listener to bind on.
fn bind_requirements(&self) -> (BTreeSet<GatewayId>, BTreeSet<IpAddr>) {
(BTreeSet::new(), BTreeSet::new())
}
fn is_passthrough(&self) -> bool {
false
}
fn preprocess<'a>(
&'a self,
prev: ServerConfig,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> impl Future<Output = Option<(ServerConfig, Self::PreprocessRes)>> + Send + 'a;
fn handle_stream(
&self,
stream: AcceptStream,
metadata: TlsMetadata<<A as Accept>::Metadata>,
prev: Self::PreprocessRes,
rc: Weak<()>,
);
}
pub trait DynVHostTargetT<A: Accept>: std::fmt::Debug + Any {
fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool;
fn acme(&self) -> Option<&AcmeProvider>;
fn bind_requirements(&self) -> (BTreeSet<GatewayId>, BTreeSet<IpAddr>);
fn is_passthrough(&self) -> bool;
fn preprocess<'a>(
&'a self,
prev: ServerConfig,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> BoxFuture<'a, Option<(ServerConfig, Box<dyn Any + Send>)>>
where
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>;
fn handle_stream(
&self,
stream: AcceptStream,
metadata: TlsMetadata<<A as Accept>::Metadata>,
prev: Box<dyn Any + Send>,
rc: Weak<()>,
);
fn eq(&self, other: &dyn DynVHostTargetT<A>) -> bool;
}
impl<A: Accept, T: VHostTarget<A> + 'static> DynVHostTargetT<A> for T {
fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool {
VHostTarget::filter(self, metadata)
}
fn acme(&self) -> Option<&AcmeProvider> {
VHostTarget::acme(self)
}
fn is_passthrough(&self) -> bool {
VHostTarget::is_passthrough(self)
}
fn bind_requirements(&self) -> (BTreeSet<GatewayId>, BTreeSet<IpAddr>) {
VHostTarget::bind_requirements(self)
}
fn preprocess<'a>(
&'a self,
prev: ServerConfig,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> BoxFuture<'a, Option<(ServerConfig, Box<dyn Any + Send>)>> {
VHostTarget::preprocess(self, prev, hello, metadata)
.map(|o| o.map(|(cfg, res)| (cfg, Box::new(res) as Box<dyn Any + Send>)))
.boxed()
}
fn handle_stream(
&self,
stream: AcceptStream,
metadata: TlsMetadata<<A as Accept>::Metadata>,
prev: Box<dyn Any + Send>,
rc: Weak<()>,
) {
if let Ok(prev) = prev.downcast() {
VHostTarget::handle_stream(self, stream, metadata, *prev, rc);
}
}
fn eq(&self, other: &dyn DynVHostTargetT<A>) -> bool {
Some(self) == (other as &dyn Any).downcast_ref()
}
}
pub struct DynVHostTarget<A: Accept>(Arc<dyn DynVHostTargetT<A> + Send + Sync>);
impl<A: Accept> DynVHostTarget<A> {
pub fn new<T: VHostTarget<A> + Send + Sync + 'static>(target: T) -> Self {
Self(Arc::new(target))
}
}
impl<A: Accept> Clone for DynVHostTarget<A> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<A: Accept> std::fmt::Debug for DynVHostTarget<A> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl<A: Accept + 'static> PartialEq for DynVHostTarget<A> {
fn eq(&self, other: &Self) -> bool {
self.0.eq(&*other.0)
}
}
impl<A: Accept + 'static> Eq for DynVHostTarget<A> {}
struct Preprocessed<A: Accept>(DynVHostTarget<A>, Weak<()>, Box<dyn Any + Send>);
impl<A: Accept> fmt::Debug for Preprocessed<A> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
(self.0).0.fmt(f)
}
}
impl<A: Accept + 'static> DynVHostTarget<A> {
async fn into_preprocessed(
self,
rc: Weak<()>,
prev: ServerConfig,
hello: &ClientHello<'_>,
metadata: &<A as Accept>::Metadata,
) -> Option<(ServerConfig, Preprocessed<A>)>
where
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>,
{
let (cfg, res) = self.0.preprocess(prev, hello, metadata).await?;
Some((cfg, Preprocessed(self, rc, res)))
}
}
impl<A: Accept + 'static> Preprocessed<A> {
fn finish(self, stream: AcceptStream, metadata: TlsMetadata<<A as Accept>::Metadata>) {
(self.0).0.handle_stream(stream, metadata, self.2, self.1);
}
}
#[derive(Clone)]
pub struct ProxyTarget {
pub public: BTreeSet<GatewayId>,
pub private: BTreeSet<IpAddr>,
pub acme: Option<AcmeProvider>,
pub addr: SocketAddr,
pub add_x_forwarded_headers: bool,
pub connect_ssl: Result<Arc<ClientConfig>, AlpnInfo>, // Ok: yes, connect using ssl, pass through alpn; Err: connect tcp, use provided strategy for alpn
pub passthrough: bool,
}
impl PartialEq for ProxyTarget {
fn eq(&self, other: &Self) -> bool {
self.public == other.public
&& self.private == other.private
&& self.acme == other.acme
&& self.addr == other.addr
&& self.passthrough == other.passthrough
&& self.connect_ssl.as_ref().map(Arc::as_ptr)
== other.connect_ssl.as_ref().map(Arc::as_ptr)
}
}
impl Eq for ProxyTarget {}
impl fmt::Debug for ProxyTarget {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ProxyTarget")
.field("public", &self.public)
.field("private", &self.private)
.field("acme", &self.acme)
.field("addr", &self.addr)
.field("add_x_forwarded_headers", &self.add_x_forwarded_headers)
.field("connect_ssl", &self.connect_ssl.as_ref().map(|_| ()))
.field("passthrough", &self.passthrough)
.finish()
}
}
impl<A> VHostTarget<A> for ProxyTarget
where
A: Accept + 'static,
<A as Accept>::Metadata: Visit<ExtractVisitor<GatewayInfo>>
+ Visit<ExtractVisitor<TcpMetadata>>
+ Clone
+ Send
+ Sync,
{
type PreprocessRes = AcceptStream;
fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool {
let gw = extract::<GatewayInfo, _>(metadata);
let tcp = extract::<TcpMetadata, _>(metadata);
let (Some(gw), Some(tcp)) = (gw, tcp) else {
return false;
};
let Some(ip_info) = &gw.info.ip_info else {
return false;
};
let src = tcp.peer_addr.ip();
// Public: source is outside all known subnets (direct internet)
let is_public = !ip_info.subnets.iter().any(|s| s.contains(&src));
if is_public {
self.public.contains(&gw.id)
} else {
// Private: accept if connection arrived on an interface with a matching IP
ip_info
.subnets
.iter()
.any(|s| self.private.contains(&s.addr()))
}
}
fn acme(&self) -> Option<&AcmeProvider> {
self.acme.as_ref()
}
fn bind_requirements(&self) -> (BTreeSet<GatewayId>, BTreeSet<IpAddr>) {
(self.public.clone(), self.private.clone())
}
fn is_passthrough(&self) -> bool {
self.passthrough
}
async fn preprocess<'a>(
&'a self,
mut prev: ServerConfig,
hello: &'a ClientHello<'a>,
_: &'a <A as Accept>::Metadata,
) -> Option<(ServerConfig, Self::PreprocessRes)> {
let tcp_stream = TcpStream::connect(self.addr)
.await
.with_ctx(|_| (ErrorKind::Network, self.addr))
.log_err()?;
if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
match &self.connect_ssl {
Ok(client_cfg) => {
let mut client_cfg = (&**client_cfg).clone();
client_cfg.alpn_protocols = hello
.alpn()
.into_iter()
.flatten()
.map(|x| x.to_vec())
.collect();
let target_stream = TlsConnector::from(Arc::new(client_cfg))
.connect_with(
ServerName::IpAddress(self.addr.ip().into()),
tcp_stream,
|conn| {
prev.alpn_protocols
.extend(conn.alpn_protocol().into_iter().map(|p| p.to_vec()))
},
)
.await
.log_err()?;
return Some((prev, Box::pin(target_stream)));
}
Err(AlpnInfo::Reflect) => {
for alpn in hello.alpn().into_iter().flatten() {
prev.alpn_protocols.push(alpn.to_vec());
}
}
Err(AlpnInfo::Specified(a)) => {
for alpn in a {
prev.alpn_protocols.push(alpn.0.clone());
}
}
}
Some((prev, Box::pin(tcp_stream)))
}
fn handle_stream(
&self,
mut stream: AcceptStream,
metadata: TlsMetadata<<A as Accept>::Metadata>,
mut prev: Self::PreprocessRes,
rc: Weak<()>,
) {
let add_x_forwarded_headers = self.add_x_forwarded_headers;
tokio::spawn(async move {
WeakFuture::new(rc, async move {
if add_x_forwarded_headers {
crate::net::http::run_http_proxy(
stream,
prev,
metadata.tls_info.alpn,
extract::<TcpMetadata, _>(&metadata.inner).map(|m| m.peer_addr.ip()),
)
.await
.ok();
} else {
tokio::io::copy_bidirectional(&mut stream, &mut prev)
.await
.ok();
}
})
.await
});
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub enum AlpnInfo {
Reflect,
Specified(Vec<MaybeUtf8String>),
}
impl Default for AlpnInfo {
fn default() -> Self {
Self::Reflect
}
}
type Mapping<A> = BTreeMap<Option<InternedString>, InOMap<DynVHostTarget<A>, Weak<()>>>;
pub struct GetVHostAcmeProvider<A: Accept + 'static>(pub Watch<Mapping<A>>);
impl<A: Accept + 'static> Clone for GetVHostAcmeProvider<A> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<A: Accept + 'static> GetAcmeProvider for GetVHostAcmeProvider<A> {
async fn get_provider<'a, 'b: 'a>(
&'b self,
san_info: &'a BTreeSet<InternedString>,
) -> Option<impl AsRef<AcmeProvider> + Send + 'b> {
self.0.peek(|m| -> Option<AcmeProvider> {
san_info
.iter()
.fold(Some::<Option<&AcmeProvider>>(None), |acc, x| {
let acc = acc?;
if x.parse::<IpAddr>().is_ok() {
return Some(acc);
}
let (t, _) = m
.get(&Some(x.clone()))?
.iter()
.find(|(_, rc)| rc.strong_count() > 0)?;
let acme = t.0.acme()?;
Some(if let Some(acc) = acc {
if acme == acc {
// all must match
Some(acme)
} else {
None
}
} else {
Some(acme)
})
})
.flatten()
.cloned()
})
}
}
pub struct VHostConnector<A: Accept + 'static>(Watch<Mapping<A>>, Option<Preprocessed<A>>);
impl<A: Accept + 'static> Clone for VHostConnector<A> {
fn clone(&self) -> Self {
Self(self.0.clone(), None)
}
}
impl<A> WrapTlsHandler<A> for VHostConnector<A>
where
A: Accept + 'static,
<A as Accept>::Metadata:
Visit<ExtractVisitor<GatewayInfo>> + Visit<ExtractVisitor<TcpMetadata>> + Send + Sync,
{
async fn wrap<'a>(
&'a mut self,
prev: ServerConfig,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> Option<TlsHandlerAction>
where
Self: 'a,
{
if hello
.alpn()
.into_iter()
.flatten()
.any(|a| a == ACME_TLS_ALPN_NAME)
{
return Some(TlsHandlerAction::Tls(prev));
}
let (target, rc) = self.0.peek(|m| {
m.get(&hello.server_name().map(InternedString::from))
.or_else(|| m.get(&None))
.into_iter()
.flatten()
.filter(|(_, rc)| rc.strong_count() > 0)
.find(|(t, _)| t.0.filter(metadata))
.map(|(t, rc)| (t.clone(), rc.clone()))
})?;
let is_pt = target.0.is_passthrough();
let (prev, store) = target.into_preprocessed(rc, prev, hello, metadata).await?;
self.1 = Some(store);
if is_pt {
Some(TlsHandlerAction::Passthrough)
} else {
Some(TlsHandlerAction::Tls(prev))
}
}
}
struct VHostListener<M, A>(
TlsListener<
A,
TlsHandlerWrapper<
ChainedHandler<Arc<AcmeTlsHandler<M, GetVHostAcmeProvider<A>>>, RootCaTlsHandler<M>>,
VHostConnector<A>,
>,
>,
)
where
for<'a> M: HasModel<Model = Model<M>>
+ DbAccessMut<CertStore>
+ DbAccessMut<AcmeCertStore>
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ Send
+ Sync,
A: Accept + 'static,
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
+ Visit<ExtractVisitor<GatewayInfo>>
+ Clone
+ Send
+ Sync
+ 'static;
struct VHostListenerMetadata<A: Accept> {
inner: TlsMetadata<A::Metadata>,
preprocessed: Preprocessed<A>,
}
impl<A: Accept> fmt::Debug for VHostListenerMetadata<A> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("VHostListenerMetadata")
.field("inner", &self.inner)
.field("preprocessed", &self.preprocessed)
.finish()
}
}
impl<M, A> Accept for VHostListener<M, A>
where
for<'a> M: HasModel<Model = Model<M>>
+ DbAccessMut<CertStore>
+ DbAccessMut<AcmeCertStore>
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ Send
+ Sync
+ 'static,
A: Accept + 'static,
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
+ Visit<ExtractVisitor<GatewayInfo>>
+ Clone
+ Send
+ Sync
+ 'static,
{
type Metadata = VHostListenerMetadata<A>;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
let (metadata, stream) = ready!(self.0.poll_accept(cx)?);
let preprocessed = self.0.tls_handler.wrapper.1.take();
Poll::Ready(Ok((
VHostListenerMetadata {
inner: metadata,
preprocessed: preprocessed.ok_or_else(|| {
Error::new(
eyre!("tlslistener yielded but preprocessed isn't set"),
ErrorKind::Incoherent,
)
})?,
},
stream,
)))
}
}
impl<M, A> VHostListener<M, A>
where
for<'a> M: HasModel<Model = Model<M>>
+ DbAccessMut<CertStore>
+ DbAccessMut<AcmeCertStore>
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ Send
+ Sync
+ 'static,
A: Accept + 'static,
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
+ Visit<ExtractVisitor<GatewayInfo>>
+ Clone
+ Send
+ Sync
+ 'static,
{
async fn handle_next(&mut self) -> Result<(), Error> {
let (metadata, stream) = futures::future::poll_fn(|cx| self.poll_accept(cx)).await?;
metadata.preprocessed.finish(stream, metadata.inner);
Ok(())
}
}
struct VHostServer<A: Accept + 'static> {
mapping: Watch<Mapping<A>>,
bind_reqs: Watch<VHostBindRequirements>,
_thread: NonDetachingJoinHandle<()>,
}
impl<A: Accept> VHostServer<A> {
#[instrument(skip_all)]
fn new<M: HasModel>(
listener: A,
bind_reqs: Watch<VHostBindRequirements>,
db: TypedPatchDb<M>,
crypto_provider: Arc<CryptoProvider>,
acme_cache: AcmeTlsAlpnCache,
) -> Self
where
for<'a> M: HasModel<Model = Model<M>>
+ DbAccessMut<CertStore>
+ DbAccessMut<AcmeCertStore>
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ Send
+ Sync
+ 'static,
A: Accept + Send + 'static,
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
+ Visit<ExtractVisitor<GatewayInfo>>
+ Clone
+ Send
+ Sync
+ 'static,
{
let mapping = Watch::new(BTreeMap::new());
Self {
mapping: mapping.clone(),
bind_reqs,
_thread: tokio::spawn(async move {
let mut listener = VHostListener(TlsListener::new(
listener,
TlsHandlerWrapper {
inner: ChainedHandler(
Arc::new(AcmeTlsHandler {
db: db.clone(),
acme_cache,
crypto_provider: crypto_provider.clone(),
get_provider: GetVHostAcmeProvider(mapping.clone()),
in_progress: Watch::new(BTreeSet::new()),
}),
RootCaTlsHandler {
db,
crypto_provider,
},
),
wrapper: VHostConnector(mapping, None),
},
));
loop {
if let Err(e) = listener.handle_next().await {
tracing::trace!("VHostServer: failed to accept connection: {e}");
tracing::trace!("{e:?}");
}
}
})
.into(),
}
}
fn add(
&self,
hostname: Option<InternedString>,
target: DynVHostTarget<A>,
) -> Result<Arc<()>, Error> {
let target = target.into();
let mut res = Ok(Arc::new(()));
self.mapping.send_if_modified(|writable| {
let mut changed = false;
let mut targets = writable.remove(&hostname).unwrap_or_default();
let rc = if let Some(rc) = Weak::upgrade(&targets.remove(&target).unwrap_or_default()) {
rc
} else {
changed = true;
Arc::new(())
};
targets.retain(|_, rc| rc.strong_count() > 0);
targets.insert(target, Arc::downgrade(&rc));
writable.insert(hostname, targets);
res = Ok(rc);
if changed {
self.update_bind_reqs(writable);
}
changed
});
if self.mapping.watcher_count() > 1 {
res
} else {
Err(Error::new(
eyre!("VHost Service Thread has exited"),
crate::ErrorKind::Network,
))
}
}
fn gc(&self, hostname: Option<InternedString>) {
self.mapping.send_if_modified(|writable| {
let mut targets = writable.remove(&hostname).unwrap_or_default();
let pre = targets.len();
targets = targets
.into_iter()
.filter(|(_, rc)| rc.strong_count() > 0)
.collect();
let post = targets.len();
if !targets.is_empty() {
writable.insert(hostname, targets);
}
if pre != post {
self.update_bind_reqs(writable);
}
pre == post
});
}
fn update_bind_reqs(&self, mapping: &Mapping<A>) {
let new_reqs = compute_bind_reqs(mapping);
self.bind_reqs.send_if_modified(|reqs| {
if *reqs != new_reqs {
*reqs = new_reqs;
true
} else {
false
}
});
}
fn is_empty(&self) -> bool {
self.mapping.peek(|m| m.is_empty())
}
}