mirror of
https://github.com/Start9Labs/start-os.git
synced 2026-03-27 02:41:53 +00:00
* task fix and keyboard fix
* fixes for build scripts
* passthrough feature
* feat: inline domain health checks and improve address UX
- addPublicDomain returns DNS query + port check results (AddPublicDomainRes)
so frontend skips separate API calls after adding a domain
- addPrivateDomain returns check_dns result for the gateway
- Support multiple ports per domain in validation modal (deduplicated)
- Run port checks concurrently via futures::future::join_all
- Add note to add-domain dialog showing other interfaces on same host
- Add addXForwardedHeaders to knownProtocols in SDK Host.ts
- Add plugin filter kind, pluginId filter, matchesAny, and docs to
getServiceInterface.ts
- Add PassthroughInfo type and passthroughs field to NetworkInfo
- Pluralize "port forwarding rules" in i18n dictionaries
* feat: add shared host note to private domain dialog with i18n
* fix: scope public domain to single binding and return single port check
Accept internalPort in AddPublicDomainParams to target a specific
binding. Disable the domain on all other bindings. Return a single
CheckPortRes instead of Vec. Revert multi-port UI to singular port
display from 0f8a66b35.
* better shared hostname approach, and improve look-feel of addresses tables
* fix starttls
* preserve usb as top efi boot option
* fix race condition in wan ip check
* sdk beta.56
* various bug, improve smtp
* multiple bugs, better outbound gateway UX
* remove non option from smtp for better package compat
* bump sdk
---------
Co-authored-by: Aiden McClelland <me@drbonez.dev>
1177 lines
38 KiB
Rust
1177 lines
38 KiB
Rust
use std::any::Any;
|
|
use std::collections::{BTreeMap, BTreeSet};
|
|
use std::fmt;
|
|
use std::net::{IpAddr, SocketAddr, SocketAddrV6};
|
|
use std::sync::{Arc, Weak};
|
|
use std::task::{Poll, ready};
|
|
|
|
use async_acme::acme::ACME_TLS_ALPN_NAME;
|
|
use clap::Parser;
|
|
use color_eyre::eyre::eyre;
|
|
use futures::FutureExt;
|
|
use futures::future::BoxFuture;
|
|
use imbl::OrdMap;
|
|
use imbl_value::{InOMap, InternedString};
|
|
use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn, from_fn_async};
|
|
use serde::{Deserialize, Serialize};
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
use tokio_rustls::TlsConnector;
|
|
use tokio_rustls::rustls::crypto::CryptoProvider;
|
|
use tokio_rustls::rustls::pki_types::ServerName;
|
|
use tokio_rustls::rustls::server::ClientHello;
|
|
use tokio_rustls::rustls::{ClientConfig, ServerConfig};
|
|
use tracing::instrument;
|
|
use ts_rs::TS;
|
|
use visit_rs::Visit;
|
|
|
|
use crate::context::{CliContext, RpcContext};
|
|
use crate::db::model::Database;
|
|
use crate::db::model::public::{AcmeSettings, NetworkInterfaceInfo};
|
|
use crate::db::{DbAccessByKey, DbAccessMut};
|
|
use crate::net::acme::{
|
|
AcmeCertStore, AcmeProvider, AcmeTlsAlpnCache, AcmeTlsHandler, GetAcmeProvider,
|
|
};
|
|
use crate::net::gateway::{
|
|
GatewayInfo, NetworkInterfaceController, NetworkInterfaceListenerAcceptMetadata,
|
|
};
|
|
use crate::net::ssl::{CertStore, RootCaTlsHandler};
|
|
use crate::net::tls::{
|
|
ChainedHandler, TlsHandlerAction, TlsHandlerWrapper, TlsListener, TlsMetadata, WrapTlsHandler,
|
|
};
|
|
use crate::net::utils::ipv6_is_link_local;
|
|
use crate::net::web_server::{Accept, AcceptStream, ExtractVisitor, TcpMetadata, extract};
|
|
use crate::prelude::*;
|
|
use crate::util::collections::EqSet;
|
|
use crate::util::future::{NonDetachingJoinHandle, WeakFuture};
|
|
use crate::util::serde::{HandlerExtSerde, MaybeUtf8String, display_serializable};
|
|
use crate::util::sync::{SyncMutex, Watch};
|
|
use crate::{GatewayId, ResultExt};
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize, HasModel, TS)]
|
|
#[serde(rename_all = "camelCase")]
|
|
#[model = "Model<Self>"]
|
|
#[ts(export)]
|
|
pub struct PassthroughInfo {
|
|
#[ts(type = "string")]
|
|
pub hostname: InternedString,
|
|
pub listen_port: u16,
|
|
#[ts(type = "string")]
|
|
pub backend: SocketAddr,
|
|
#[ts(type = "string[]")]
|
|
pub public_gateways: BTreeSet<GatewayId>,
|
|
#[ts(type = "string[]")]
|
|
pub private_ips: BTreeSet<IpAddr>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize, Parser)]
|
|
#[serde(rename_all = "kebab-case")]
|
|
struct AddPassthroughParams {
|
|
#[arg(long)]
|
|
pub hostname: InternedString,
|
|
#[arg(long)]
|
|
pub listen_port: u16,
|
|
#[arg(long)]
|
|
pub backend: SocketAddr,
|
|
#[arg(long)]
|
|
pub public_gateway: Vec<GatewayId>,
|
|
#[arg(long)]
|
|
pub private_ip: Vec<IpAddr>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize, Parser)]
|
|
#[serde(rename_all = "kebab-case")]
|
|
struct RemovePassthroughParams {
|
|
#[arg(long)]
|
|
pub hostname: InternedString,
|
|
#[arg(long)]
|
|
pub listen_port: u16,
|
|
}
|
|
|
|
pub fn vhost_api<C: Context>() -> ParentHandler<C> {
|
|
ParentHandler::new()
|
|
.subcommand(
|
|
"dump-table",
|
|
from_fn(dump_table)
|
|
.with_display_serializable()
|
|
.with_custom_display_fn(|HandlerArgs { params, .. }, res| {
|
|
use prettytable::*;
|
|
|
|
if let Some(format) = params.format {
|
|
display_serializable(format, res)?;
|
|
return Ok::<_, Error>(());
|
|
}
|
|
|
|
let mut table = Table::new();
|
|
table.add_row(row![bc => "FROM", "TO", "ACTIVE"]);
|
|
|
|
for (external, targets) in res {
|
|
for (host, targets) in targets {
|
|
for (idx, target) in targets.into_iter().enumerate() {
|
|
table.add_row(row![
|
|
format!(
|
|
"{}:{}",
|
|
host.as_ref().map(|s| &**s).unwrap_or("*"),
|
|
external.0
|
|
),
|
|
target,
|
|
idx == 0
|
|
]);
|
|
}
|
|
}
|
|
}
|
|
|
|
table.print_tty(false)?;
|
|
|
|
Ok(())
|
|
})
|
|
.with_call_remote::<CliContext>(),
|
|
)
|
|
.subcommand(
|
|
"add-passthrough",
|
|
from_fn_async(add_passthrough)
|
|
.no_display()
|
|
.with_call_remote::<CliContext>(),
|
|
)
|
|
.subcommand(
|
|
"remove-passthrough",
|
|
from_fn_async(remove_passthrough)
|
|
.no_display()
|
|
.with_call_remote::<CliContext>(),
|
|
)
|
|
.subcommand(
|
|
"list-passthrough",
|
|
from_fn(list_passthrough)
|
|
.with_display_serializable()
|
|
.with_call_remote::<CliContext>(),
|
|
)
|
|
}
|
|
|
|
fn dump_table(
|
|
ctx: RpcContext,
|
|
) -> Result<BTreeMap<JsonKey<u16>, BTreeMap<JsonKey<Option<InternedString>>, EqSet<String>>>, Error>
|
|
{
|
|
Ok(ctx.net_controller.vhost.dump_table())
|
|
}
|
|
|
|
async fn add_passthrough(
|
|
ctx: RpcContext,
|
|
AddPassthroughParams {
|
|
hostname,
|
|
listen_port,
|
|
backend,
|
|
public_gateway,
|
|
private_ip,
|
|
}: AddPassthroughParams,
|
|
) -> Result<(), Error> {
|
|
let public_gateways: BTreeSet<GatewayId> = public_gateway.into_iter().collect();
|
|
let private_ips: BTreeSet<IpAddr> = private_ip.into_iter().collect();
|
|
ctx.net_controller.vhost.add_passthrough(
|
|
hostname.clone(),
|
|
listen_port,
|
|
backend,
|
|
public_gateways.clone(),
|
|
private_ips.clone(),
|
|
)?;
|
|
ctx.db
|
|
.mutate(|db| {
|
|
let pts = db
|
|
.as_public_mut()
|
|
.as_server_info_mut()
|
|
.as_network_mut()
|
|
.as_passthroughs_mut();
|
|
let mut vec: Vec<PassthroughInfo> = pts.de()?;
|
|
vec.retain(|p| !(p.hostname == hostname && p.listen_port == listen_port));
|
|
vec.push(PassthroughInfo {
|
|
hostname,
|
|
listen_port,
|
|
backend,
|
|
public_gateways,
|
|
private_ips,
|
|
});
|
|
pts.ser(&vec)
|
|
})
|
|
.await
|
|
.result?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn remove_passthrough(
|
|
ctx: RpcContext,
|
|
RemovePassthroughParams {
|
|
hostname,
|
|
listen_port,
|
|
}: RemovePassthroughParams,
|
|
) -> Result<(), Error> {
|
|
ctx.net_controller
|
|
.vhost
|
|
.remove_passthrough(&hostname, listen_port);
|
|
ctx.db
|
|
.mutate(|db| {
|
|
let pts = db
|
|
.as_public_mut()
|
|
.as_server_info_mut()
|
|
.as_network_mut()
|
|
.as_passthroughs_mut();
|
|
let mut vec: Vec<PassthroughInfo> = pts.de()?;
|
|
vec.retain(|p| !(p.hostname == hostname && p.listen_port == listen_port));
|
|
pts.ser(&vec)
|
|
})
|
|
.await
|
|
.result?;
|
|
Ok(())
|
|
}
|
|
|
|
fn list_passthrough(ctx: RpcContext) -> Result<Vec<PassthroughInfo>, Error> {
|
|
Ok(ctx.net_controller.vhost.list_passthrough())
|
|
}
|
|
|
|
// not allowed: <=1024, >=32768, 5355, 5432, 9050, 6010, 9051, 5353
|
|
|
|
struct PassthroughHandle {
|
|
_rc: Arc<()>,
|
|
backend: SocketAddr,
|
|
public: BTreeSet<GatewayId>,
|
|
private: BTreeSet<IpAddr>,
|
|
}
|
|
|
|
pub struct VHostController {
|
|
db: TypedPatchDb<Database>,
|
|
interfaces: Arc<NetworkInterfaceController>,
|
|
crypto_provider: Arc<CryptoProvider>,
|
|
acme_cache: AcmeTlsAlpnCache,
|
|
servers: SyncMutex<BTreeMap<u16, VHostServer<VHostBindListener>>>,
|
|
passthrough_handles: SyncMutex<BTreeMap<(InternedString, u16), PassthroughHandle>>,
|
|
}
|
|
impl VHostController {
|
|
pub fn new(
|
|
db: TypedPatchDb<Database>,
|
|
interfaces: Arc<NetworkInterfaceController>,
|
|
crypto_provider: Arc<CryptoProvider>,
|
|
passthroughs: Vec<PassthroughInfo>,
|
|
) -> Self {
|
|
let controller = Self {
|
|
db,
|
|
interfaces,
|
|
crypto_provider,
|
|
acme_cache: Arc::new(SyncMutex::new(BTreeMap::new())),
|
|
servers: SyncMutex::new(BTreeMap::new()),
|
|
passthrough_handles: SyncMutex::new(BTreeMap::new()),
|
|
};
|
|
for pt in passthroughs {
|
|
if let Err(e) = controller.add_passthrough(
|
|
pt.hostname,
|
|
pt.listen_port,
|
|
pt.backend,
|
|
pt.public_gateways,
|
|
pt.private_ips,
|
|
) {
|
|
tracing::warn!("failed to restore passthrough: {e}");
|
|
}
|
|
}
|
|
controller
|
|
}
|
|
#[instrument(skip_all)]
|
|
pub fn add(
|
|
&self,
|
|
hostname: Option<InternedString>,
|
|
external: u16,
|
|
target: DynVHostTarget<VHostBindListener>,
|
|
) -> Result<Arc<()>, Error> {
|
|
self.servers.mutate(|writable| {
|
|
let server = if let Some(server) = writable.remove(&external) {
|
|
server
|
|
} else {
|
|
self.create_server(external)
|
|
};
|
|
let rc = server.add(hostname, target);
|
|
writable.insert(external, server);
|
|
Ok(rc?)
|
|
})
|
|
}
|
|
|
|
fn create_server(&self, port: u16) -> VHostServer<VHostBindListener> {
|
|
let bind_reqs = Watch::new(VHostBindRequirements::default());
|
|
let listener = VHostBindListener {
|
|
ip_info: self.interfaces.watcher.subscribe(),
|
|
port,
|
|
bind_reqs: bind_reqs.clone_unseen(),
|
|
listeners: BTreeMap::new(),
|
|
};
|
|
VHostServer::new(
|
|
listener,
|
|
bind_reqs,
|
|
self.db.clone(),
|
|
self.crypto_provider.clone(),
|
|
self.acme_cache.clone(),
|
|
)
|
|
}
|
|
|
|
pub fn add_passthrough(
|
|
&self,
|
|
hostname: InternedString,
|
|
port: u16,
|
|
backend: SocketAddr,
|
|
public: BTreeSet<GatewayId>,
|
|
private: BTreeSet<IpAddr>,
|
|
) -> Result<(), Error> {
|
|
let target = ProxyTarget {
|
|
public: public.clone(),
|
|
private: private.clone(),
|
|
acme: None,
|
|
addr: backend,
|
|
add_x_forwarded_headers: false,
|
|
connect_ssl: Err(AlpnInfo::Reflect),
|
|
passthrough: true,
|
|
};
|
|
let rc = self.add(Some(hostname.clone()), port, DynVHostTarget::new(target))?;
|
|
self.passthrough_handles.mutate(|h| {
|
|
h.insert(
|
|
(hostname, port),
|
|
PassthroughHandle {
|
|
_rc: rc,
|
|
backend,
|
|
public,
|
|
private,
|
|
},
|
|
);
|
|
});
|
|
Ok(())
|
|
}
|
|
|
|
pub fn remove_passthrough(&self, hostname: &InternedString, port: u16) {
|
|
self.passthrough_handles
|
|
.mutate(|h| h.remove(&(hostname.clone(), port)));
|
|
self.gc(Some(hostname.clone()), port);
|
|
}
|
|
|
|
pub fn list_passthrough(&self) -> Vec<PassthroughInfo> {
|
|
self.passthrough_handles.peek(|h| {
|
|
h.iter()
|
|
.map(|((hostname, port), handle)| PassthroughInfo {
|
|
hostname: hostname.clone(),
|
|
listen_port: *port,
|
|
backend: handle.backend,
|
|
public_gateways: handle.public.clone(),
|
|
private_ips: handle.private.clone(),
|
|
})
|
|
.collect()
|
|
})
|
|
}
|
|
|
|
pub fn dump_table(
|
|
&self,
|
|
) -> BTreeMap<JsonKey<u16>, BTreeMap<JsonKey<Option<InternedString>>, EqSet<String>>> {
|
|
self.servers.peek(|s| {
|
|
s.iter()
|
|
.map(|(k, v)| {
|
|
(
|
|
JsonKey::new(*k),
|
|
v.mapping.peek(|m| {
|
|
m.iter()
|
|
.map(|(k, v)| {
|
|
(
|
|
JsonKey::new(k.clone()),
|
|
v.iter()
|
|
.filter(|(_, v)| v.strong_count() > 0)
|
|
.map(|(k, _)| format!("{k:#?}"))
|
|
.collect(),
|
|
)
|
|
})
|
|
.collect()
|
|
}),
|
|
)
|
|
})
|
|
.collect()
|
|
})
|
|
}
|
|
|
|
#[instrument(skip_all)]
|
|
pub fn gc(&self, hostname: Option<InternedString>, external: u16) {
|
|
self.servers.mutate(|writable| {
|
|
if let Some(server) = writable.remove(&external) {
|
|
server.gc(hostname);
|
|
if !server.is_empty() {
|
|
writable.insert(external, server);
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Union of all ProxyTargets' bind requirements for a VHostServer.
|
|
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
|
pub struct VHostBindRequirements {
|
|
pub public_gateways: BTreeSet<GatewayId>,
|
|
pub private_ips: BTreeSet<IpAddr>,
|
|
}
|
|
|
|
fn compute_bind_reqs<A: Accept + 'static>(mapping: &Mapping<A>) -> VHostBindRequirements {
|
|
let mut reqs = VHostBindRequirements::default();
|
|
for (_, targets) in mapping {
|
|
for (target, rc) in targets {
|
|
if rc.strong_count() > 0 {
|
|
let (pub_gw, priv_ip) = target.0.bind_requirements();
|
|
reqs.public_gateways.extend(pub_gw);
|
|
reqs.private_ips.extend(priv_ip);
|
|
}
|
|
}
|
|
}
|
|
reqs
|
|
}
|
|
|
|
/// Listener that manages its own TCP listeners with IP-level precision.
|
|
/// Binds ALL IPs of public gateways and ONLY matching private IPs.
|
|
pub struct VHostBindListener {
|
|
ip_info: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
|
|
port: u16,
|
|
bind_reqs: Watch<VHostBindRequirements>,
|
|
listeners: BTreeMap<SocketAddr, (TcpListener, GatewayInfo)>,
|
|
}
|
|
|
|
fn update_vhost_listeners(
|
|
listeners: &mut BTreeMap<SocketAddr, (TcpListener, GatewayInfo)>,
|
|
port: u16,
|
|
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
|
|
reqs: &VHostBindRequirements,
|
|
) -> Result<(), Error> {
|
|
let mut keep = BTreeSet::<SocketAddr>::new();
|
|
for (gw_id, info) in ip_info {
|
|
if let Some(ip_info) = &info.ip_info {
|
|
for ipnet in &ip_info.subnets {
|
|
let ip = ipnet.addr();
|
|
let should_bind =
|
|
reqs.public_gateways.contains(gw_id) || reqs.private_ips.contains(&ip);
|
|
if should_bind {
|
|
let addr = match ip {
|
|
IpAddr::V6(ip6) => SocketAddrV6::new(
|
|
ip6,
|
|
port,
|
|
0,
|
|
if ipv6_is_link_local(ip6) {
|
|
ip_info.scope_id
|
|
} else {
|
|
0
|
|
},
|
|
)
|
|
.into(),
|
|
ip => SocketAddr::new(ip, port),
|
|
};
|
|
keep.insert(addr);
|
|
if let Some((_, existing_info)) = listeners.get_mut(&addr) {
|
|
*existing_info = GatewayInfo {
|
|
id: gw_id.clone(),
|
|
info: info.clone(),
|
|
};
|
|
} else {
|
|
let tcp = TcpListener::from_std(
|
|
mio::net::TcpListener::bind(addr)
|
|
.with_kind(ErrorKind::Network)?
|
|
.into(),
|
|
)
|
|
.with_kind(ErrorKind::Network)?;
|
|
listeners.insert(
|
|
addr,
|
|
(
|
|
tcp,
|
|
GatewayInfo {
|
|
id: gw_id.clone(),
|
|
info: info.clone(),
|
|
},
|
|
),
|
|
);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
listeners.retain(|key, _| keep.contains(key));
|
|
Ok(())
|
|
}
|
|
|
|
impl Accept for VHostBindListener {
|
|
type Metadata = NetworkInterfaceListenerAcceptMetadata;
|
|
fn poll_accept(
|
|
&mut self,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
|
|
// Update listeners when ip_info or bind_reqs change
|
|
while self.ip_info.poll_changed(cx).is_ready() || self.bind_reqs.poll_changed(cx).is_ready()
|
|
{
|
|
let reqs = self.bind_reqs.read_and_mark_seen();
|
|
let listeners = &mut self.listeners;
|
|
let port = self.port;
|
|
self.ip_info.peek_and_mark_seen(|ip_info| {
|
|
update_vhost_listeners(listeners, port, ip_info, &reqs)
|
|
})?;
|
|
}
|
|
|
|
// Poll each listener for incoming connections
|
|
for (&addr, (listener, gw_info)) in &self.listeners {
|
|
match listener.poll_accept(cx) {
|
|
Poll::Ready(Ok((stream, peer_addr))) => {
|
|
if let Err(e) = socket2::SockRef::from(&stream).set_keepalive(true) {
|
|
tracing::error!("Failed to set tcp keepalive: {e}");
|
|
tracing::debug!("{e:?}");
|
|
}
|
|
return Poll::Ready(Ok((
|
|
NetworkInterfaceListenerAcceptMetadata {
|
|
inner: TcpMetadata {
|
|
local_addr: addr,
|
|
peer_addr,
|
|
},
|
|
info: gw_info.clone(),
|
|
},
|
|
Box::pin(stream),
|
|
)));
|
|
}
|
|
Poll::Ready(Err(e)) => {
|
|
tracing::trace!("VHostBindListener accept error on {addr}: {e}");
|
|
}
|
|
Poll::Pending => {}
|
|
}
|
|
}
|
|
Poll::Pending
|
|
}
|
|
}
|
|
|
|
pub trait VHostTarget<A: Accept>: std::fmt::Debug + Eq {
|
|
type PreprocessRes: Send + 'static;
|
|
#[allow(unused_variables)]
|
|
fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool {
|
|
true
|
|
}
|
|
fn acme(&self) -> Option<&AcmeProvider> {
|
|
None
|
|
}
|
|
/// Returns (public_gateways, private_ips) this target needs the listener to bind on.
|
|
fn bind_requirements(&self) -> (BTreeSet<GatewayId>, BTreeSet<IpAddr>) {
|
|
(BTreeSet::new(), BTreeSet::new())
|
|
}
|
|
fn is_passthrough(&self) -> bool {
|
|
false
|
|
}
|
|
fn preprocess<'a>(
|
|
&'a self,
|
|
prev: ServerConfig,
|
|
hello: &'a ClientHello<'a>,
|
|
metadata: &'a <A as Accept>::Metadata,
|
|
) -> impl Future<Output = Option<(ServerConfig, Self::PreprocessRes)>> + Send + 'a;
|
|
fn handle_stream(
|
|
&self,
|
|
stream: AcceptStream,
|
|
metadata: TlsMetadata<<A as Accept>::Metadata>,
|
|
prev: Self::PreprocessRes,
|
|
rc: Weak<()>,
|
|
);
|
|
}
|
|
|
|
pub trait DynVHostTargetT<A: Accept>: std::fmt::Debug + Any {
|
|
fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool;
|
|
fn acme(&self) -> Option<&AcmeProvider>;
|
|
fn bind_requirements(&self) -> (BTreeSet<GatewayId>, BTreeSet<IpAddr>);
|
|
fn is_passthrough(&self) -> bool;
|
|
fn preprocess<'a>(
|
|
&'a self,
|
|
prev: ServerConfig,
|
|
hello: &'a ClientHello<'a>,
|
|
metadata: &'a <A as Accept>::Metadata,
|
|
) -> BoxFuture<'a, Option<(ServerConfig, Box<dyn Any + Send>)>>
|
|
where
|
|
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>;
|
|
fn handle_stream(
|
|
&self,
|
|
stream: AcceptStream,
|
|
metadata: TlsMetadata<<A as Accept>::Metadata>,
|
|
prev: Box<dyn Any + Send>,
|
|
rc: Weak<()>,
|
|
);
|
|
fn eq(&self, other: &dyn DynVHostTargetT<A>) -> bool;
|
|
}
|
|
impl<A: Accept, T: VHostTarget<A> + 'static> DynVHostTargetT<A> for T {
|
|
fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool {
|
|
VHostTarget::filter(self, metadata)
|
|
}
|
|
fn acme(&self) -> Option<&AcmeProvider> {
|
|
VHostTarget::acme(self)
|
|
}
|
|
fn is_passthrough(&self) -> bool {
|
|
VHostTarget::is_passthrough(self)
|
|
}
|
|
fn bind_requirements(&self) -> (BTreeSet<GatewayId>, BTreeSet<IpAddr>) {
|
|
VHostTarget::bind_requirements(self)
|
|
}
|
|
fn preprocess<'a>(
|
|
&'a self,
|
|
prev: ServerConfig,
|
|
hello: &'a ClientHello<'a>,
|
|
metadata: &'a <A as Accept>::Metadata,
|
|
) -> BoxFuture<'a, Option<(ServerConfig, Box<dyn Any + Send>)>> {
|
|
VHostTarget::preprocess(self, prev, hello, metadata)
|
|
.map(|o| o.map(|(cfg, res)| (cfg, Box::new(res) as Box<dyn Any + Send>)))
|
|
.boxed()
|
|
}
|
|
fn handle_stream(
|
|
&self,
|
|
stream: AcceptStream,
|
|
metadata: TlsMetadata<<A as Accept>::Metadata>,
|
|
prev: Box<dyn Any + Send>,
|
|
rc: Weak<()>,
|
|
) {
|
|
if let Ok(prev) = prev.downcast() {
|
|
VHostTarget::handle_stream(self, stream, metadata, *prev, rc);
|
|
}
|
|
}
|
|
fn eq(&self, other: &dyn DynVHostTargetT<A>) -> bool {
|
|
Some(self) == (other as &dyn Any).downcast_ref()
|
|
}
|
|
}
|
|
|
|
pub struct DynVHostTarget<A: Accept>(Arc<dyn DynVHostTargetT<A> + Send + Sync>);
|
|
impl<A: Accept> DynVHostTarget<A> {
|
|
pub fn new<T: VHostTarget<A> + Send + Sync + 'static>(target: T) -> Self {
|
|
Self(Arc::new(target))
|
|
}
|
|
}
|
|
impl<A: Accept> Clone for DynVHostTarget<A> {
|
|
fn clone(&self) -> Self {
|
|
Self(self.0.clone())
|
|
}
|
|
}
|
|
impl<A: Accept> std::fmt::Debug for DynVHostTarget<A> {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
self.0.fmt(f)
|
|
}
|
|
}
|
|
impl<A: Accept + 'static> PartialEq for DynVHostTarget<A> {
|
|
fn eq(&self, other: &Self) -> bool {
|
|
self.0.eq(&*other.0)
|
|
}
|
|
}
|
|
impl<A: Accept + 'static> Eq for DynVHostTarget<A> {}
|
|
struct Preprocessed<A: Accept>(DynVHostTarget<A>, Weak<()>, Box<dyn Any + Send>);
|
|
impl<A: Accept> fmt::Debug for Preprocessed<A> {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
(self.0).0.fmt(f)
|
|
}
|
|
}
|
|
impl<A: Accept + 'static> DynVHostTarget<A> {
|
|
async fn into_preprocessed(
|
|
self,
|
|
rc: Weak<()>,
|
|
prev: ServerConfig,
|
|
hello: &ClientHello<'_>,
|
|
metadata: &<A as Accept>::Metadata,
|
|
) -> Option<(ServerConfig, Preprocessed<A>)>
|
|
where
|
|
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>,
|
|
{
|
|
let (cfg, res) = self.0.preprocess(prev, hello, metadata).await?;
|
|
Some((cfg, Preprocessed(self, rc, res)))
|
|
}
|
|
}
|
|
impl<A: Accept + 'static> Preprocessed<A> {
|
|
fn finish(self, stream: AcceptStream, metadata: TlsMetadata<<A as Accept>::Metadata>) {
|
|
(self.0).0.handle_stream(stream, metadata, self.2, self.1);
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct ProxyTarget {
|
|
pub public: BTreeSet<GatewayId>,
|
|
pub private: BTreeSet<IpAddr>,
|
|
pub acme: Option<AcmeProvider>,
|
|
pub addr: SocketAddr,
|
|
pub add_x_forwarded_headers: bool,
|
|
pub connect_ssl: Result<Arc<ClientConfig>, AlpnInfo>, // Ok: yes, connect using ssl, pass through alpn; Err: connect tcp, use provided strategy for alpn
|
|
pub passthrough: bool,
|
|
}
|
|
impl PartialEq for ProxyTarget {
|
|
fn eq(&self, other: &Self) -> bool {
|
|
self.public == other.public
|
|
&& self.private == other.private
|
|
&& self.acme == other.acme
|
|
&& self.addr == other.addr
|
|
&& self.passthrough == other.passthrough
|
|
&& self.connect_ssl.as_ref().map(Arc::as_ptr)
|
|
== other.connect_ssl.as_ref().map(Arc::as_ptr)
|
|
}
|
|
}
|
|
impl Eq for ProxyTarget {}
|
|
impl fmt::Debug for ProxyTarget {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
f.debug_struct("ProxyTarget")
|
|
.field("public", &self.public)
|
|
.field("private", &self.private)
|
|
.field("acme", &self.acme)
|
|
.field("addr", &self.addr)
|
|
.field("add_x_forwarded_headers", &self.add_x_forwarded_headers)
|
|
.field("connect_ssl", &self.connect_ssl.as_ref().map(|_| ()))
|
|
.field("passthrough", &self.passthrough)
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
impl<A> VHostTarget<A> for ProxyTarget
|
|
where
|
|
A: Accept + 'static,
|
|
<A as Accept>::Metadata: Visit<ExtractVisitor<GatewayInfo>>
|
|
+ Visit<ExtractVisitor<TcpMetadata>>
|
|
+ Clone
|
|
+ Send
|
|
+ Sync,
|
|
{
|
|
type PreprocessRes = AcceptStream;
|
|
fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool {
|
|
let gw = extract::<GatewayInfo, _>(metadata);
|
|
let tcp = extract::<TcpMetadata, _>(metadata);
|
|
let (Some(gw), Some(tcp)) = (gw, tcp) else {
|
|
return false;
|
|
};
|
|
let Some(ip_info) = &gw.info.ip_info else {
|
|
return false;
|
|
};
|
|
|
|
let src = tcp.peer_addr.ip();
|
|
// Public: source is outside all known subnets (direct internet)
|
|
let is_public = !ip_info.subnets.iter().any(|s| s.contains(&src));
|
|
|
|
if is_public {
|
|
self.public.contains(&gw.id)
|
|
} else {
|
|
// Private: accept if connection arrived on an interface with a matching IP
|
|
ip_info
|
|
.subnets
|
|
.iter()
|
|
.any(|s| self.private.contains(&s.addr()))
|
|
}
|
|
}
|
|
fn acme(&self) -> Option<&AcmeProvider> {
|
|
self.acme.as_ref()
|
|
}
|
|
fn bind_requirements(&self) -> (BTreeSet<GatewayId>, BTreeSet<IpAddr>) {
|
|
(self.public.clone(), self.private.clone())
|
|
}
|
|
fn is_passthrough(&self) -> bool {
|
|
self.passthrough
|
|
}
|
|
async fn preprocess<'a>(
|
|
&'a self,
|
|
mut prev: ServerConfig,
|
|
hello: &'a ClientHello<'a>,
|
|
_: &'a <A as Accept>::Metadata,
|
|
) -> Option<(ServerConfig, Self::PreprocessRes)> {
|
|
let tcp_stream = TcpStream::connect(self.addr)
|
|
.await
|
|
.with_ctx(|_| (ErrorKind::Network, self.addr))
|
|
.log_err()?;
|
|
if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) {
|
|
tracing::error!("Failed to set tcp keepalive: {e}");
|
|
tracing::debug!("{e:?}");
|
|
}
|
|
match &self.connect_ssl {
|
|
Ok(client_cfg) => {
|
|
let mut client_cfg = (&**client_cfg).clone();
|
|
client_cfg.alpn_protocols = hello
|
|
.alpn()
|
|
.into_iter()
|
|
.flatten()
|
|
.map(|x| x.to_vec())
|
|
.collect();
|
|
let target_stream = TlsConnector::from(Arc::new(client_cfg))
|
|
.connect_with(
|
|
ServerName::IpAddress(self.addr.ip().into()),
|
|
tcp_stream,
|
|
|conn| {
|
|
prev.alpn_protocols
|
|
.extend(conn.alpn_protocol().into_iter().map(|p| p.to_vec()))
|
|
},
|
|
)
|
|
.await
|
|
.log_err()?;
|
|
return Some((prev, Box::pin(target_stream)));
|
|
}
|
|
Err(AlpnInfo::Reflect) => {
|
|
for alpn in hello.alpn().into_iter().flatten() {
|
|
prev.alpn_protocols.push(alpn.to_vec());
|
|
}
|
|
}
|
|
Err(AlpnInfo::Specified(a)) => {
|
|
for alpn in a {
|
|
prev.alpn_protocols.push(alpn.0.clone());
|
|
}
|
|
}
|
|
}
|
|
Some((prev, Box::pin(tcp_stream)))
|
|
}
|
|
fn handle_stream(
|
|
&self,
|
|
mut stream: AcceptStream,
|
|
metadata: TlsMetadata<<A as Accept>::Metadata>,
|
|
mut prev: Self::PreprocessRes,
|
|
rc: Weak<()>,
|
|
) {
|
|
let add_x_forwarded_headers = self.add_x_forwarded_headers;
|
|
tokio::spawn(async move {
|
|
WeakFuture::new(rc, async move {
|
|
if add_x_forwarded_headers {
|
|
crate::net::http::run_http_proxy(
|
|
stream,
|
|
prev,
|
|
metadata.tls_info.alpn,
|
|
extract::<TcpMetadata, _>(&metadata.inner).map(|m| m.peer_addr.ip()),
|
|
)
|
|
.await
|
|
.ok();
|
|
} else {
|
|
tokio::io::copy_bidirectional(&mut stream, &mut prev)
|
|
.await
|
|
.ok();
|
|
}
|
|
})
|
|
.await
|
|
});
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, TS)]
|
|
#[serde(rename_all = "camelCase")]
|
|
#[ts(export)]
|
|
pub enum AlpnInfo {
|
|
Reflect,
|
|
Specified(Vec<MaybeUtf8String>),
|
|
}
|
|
impl Default for AlpnInfo {
|
|
fn default() -> Self {
|
|
Self::Reflect
|
|
}
|
|
}
|
|
|
|
type Mapping<A> = BTreeMap<Option<InternedString>, InOMap<DynVHostTarget<A>, Weak<()>>>;
|
|
|
|
pub struct GetVHostAcmeProvider<A: Accept + 'static>(pub Watch<Mapping<A>>);
|
|
impl<A: Accept + 'static> Clone for GetVHostAcmeProvider<A> {
|
|
fn clone(&self) -> Self {
|
|
Self(self.0.clone())
|
|
}
|
|
}
|
|
impl<A: Accept + 'static> GetAcmeProvider for GetVHostAcmeProvider<A> {
|
|
async fn get_provider<'a, 'b: 'a>(
|
|
&'b self,
|
|
san_info: &'a BTreeSet<InternedString>,
|
|
) -> Option<impl AsRef<AcmeProvider> + Send + 'b> {
|
|
self.0.peek(|m| -> Option<AcmeProvider> {
|
|
san_info
|
|
.iter()
|
|
.fold(Some::<Option<&AcmeProvider>>(None), |acc, x| {
|
|
let acc = acc?;
|
|
if x.parse::<IpAddr>().is_ok() {
|
|
return Some(acc);
|
|
}
|
|
let (t, _) = m
|
|
.get(&Some(x.clone()))?
|
|
.iter()
|
|
.find(|(_, rc)| rc.strong_count() > 0)?;
|
|
let acme = t.0.acme()?;
|
|
Some(if let Some(acc) = acc {
|
|
if acme == acc {
|
|
// all must match
|
|
Some(acme)
|
|
} else {
|
|
None
|
|
}
|
|
} else {
|
|
Some(acme)
|
|
})
|
|
})
|
|
.flatten()
|
|
.cloned()
|
|
})
|
|
}
|
|
}
|
|
|
|
pub struct VHostConnector<A: Accept + 'static>(Watch<Mapping<A>>, Option<Preprocessed<A>>);
|
|
impl<A: Accept + 'static> Clone for VHostConnector<A> {
|
|
fn clone(&self) -> Self {
|
|
Self(self.0.clone(), None)
|
|
}
|
|
}
|
|
|
|
impl<A> WrapTlsHandler<A> for VHostConnector<A>
|
|
where
|
|
A: Accept + 'static,
|
|
<A as Accept>::Metadata:
|
|
Visit<ExtractVisitor<GatewayInfo>> + Visit<ExtractVisitor<TcpMetadata>> + Send + Sync,
|
|
{
|
|
async fn wrap<'a>(
|
|
&'a mut self,
|
|
prev: ServerConfig,
|
|
hello: &'a ClientHello<'a>,
|
|
metadata: &'a <A as Accept>::Metadata,
|
|
) -> Option<TlsHandlerAction>
|
|
where
|
|
Self: 'a,
|
|
{
|
|
if hello
|
|
.alpn()
|
|
.into_iter()
|
|
.flatten()
|
|
.any(|a| a == ACME_TLS_ALPN_NAME)
|
|
{
|
|
return Some(TlsHandlerAction::Tls(prev));
|
|
}
|
|
|
|
let (target, rc) = self.0.peek(|m| {
|
|
m.get(&hello.server_name().map(InternedString::from))
|
|
.or_else(|| m.get(&None))
|
|
.into_iter()
|
|
.flatten()
|
|
.filter(|(_, rc)| rc.strong_count() > 0)
|
|
.find(|(t, _)| t.0.filter(metadata))
|
|
.map(|(t, rc)| (t.clone(), rc.clone()))
|
|
})?;
|
|
|
|
let is_pt = target.0.is_passthrough();
|
|
let (prev, store) = target.into_preprocessed(rc, prev, hello, metadata).await?;
|
|
|
|
self.1 = Some(store);
|
|
|
|
if is_pt {
|
|
Some(TlsHandlerAction::Passthrough)
|
|
} else {
|
|
Some(TlsHandlerAction::Tls(prev))
|
|
}
|
|
}
|
|
}
|
|
|
|
struct VHostListener<M, A>(
|
|
TlsListener<
|
|
A,
|
|
TlsHandlerWrapper<
|
|
ChainedHandler<Arc<AcmeTlsHandler<M, GetVHostAcmeProvider<A>>>, RootCaTlsHandler<M>>,
|
|
VHostConnector<A>,
|
|
>,
|
|
>,
|
|
)
|
|
where
|
|
for<'a> M: HasModel<Model = Model<M>>
|
|
+ DbAccessMut<CertStore>
|
|
+ DbAccessMut<AcmeCertStore>
|
|
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
|
|
+ Send
|
|
+ Sync,
|
|
A: Accept + 'static,
|
|
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
|
|
+ Visit<ExtractVisitor<GatewayInfo>>
|
|
+ Clone
|
|
+ Send
|
|
+ Sync
|
|
+ 'static;
|
|
struct VHostListenerMetadata<A: Accept> {
|
|
inner: TlsMetadata<A::Metadata>,
|
|
preprocessed: Preprocessed<A>,
|
|
}
|
|
impl<A: Accept> fmt::Debug for VHostListenerMetadata<A> {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
f.debug_struct("VHostListenerMetadata")
|
|
.field("inner", &self.inner)
|
|
.field("preprocessed", &self.preprocessed)
|
|
.finish()
|
|
}
|
|
}
|
|
impl<M, A> Accept for VHostListener<M, A>
|
|
where
|
|
for<'a> M: HasModel<Model = Model<M>>
|
|
+ DbAccessMut<CertStore>
|
|
+ DbAccessMut<AcmeCertStore>
|
|
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
|
|
+ Send
|
|
+ Sync
|
|
+ 'static,
|
|
A: Accept + 'static,
|
|
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
|
|
+ Visit<ExtractVisitor<GatewayInfo>>
|
|
+ Clone
|
|
+ Send
|
|
+ Sync
|
|
+ 'static,
|
|
{
|
|
type Metadata = VHostListenerMetadata<A>;
|
|
fn poll_accept(
|
|
&mut self,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
|
|
let (metadata, stream) = ready!(self.0.poll_accept(cx)?);
|
|
let preprocessed = self.0.tls_handler.wrapper.1.take();
|
|
Poll::Ready(Ok((
|
|
VHostListenerMetadata {
|
|
inner: metadata,
|
|
preprocessed: preprocessed.ok_or_else(|| {
|
|
Error::new(
|
|
eyre!("tlslistener yielded but preprocessed isn't set"),
|
|
ErrorKind::Incoherent,
|
|
)
|
|
})?,
|
|
},
|
|
stream,
|
|
)))
|
|
}
|
|
}
|
|
impl<M, A> VHostListener<M, A>
|
|
where
|
|
for<'a> M: HasModel<Model = Model<M>>
|
|
+ DbAccessMut<CertStore>
|
|
+ DbAccessMut<AcmeCertStore>
|
|
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
|
|
+ Send
|
|
+ Sync
|
|
+ 'static,
|
|
A: Accept + 'static,
|
|
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
|
|
+ Visit<ExtractVisitor<GatewayInfo>>
|
|
+ Clone
|
|
+ Send
|
|
+ Sync
|
|
+ 'static,
|
|
{
|
|
async fn handle_next(&mut self) -> Result<(), Error> {
|
|
let (metadata, stream) = futures::future::poll_fn(|cx| self.poll_accept(cx)).await?;
|
|
|
|
metadata.preprocessed.finish(stream, metadata.inner);
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
struct VHostServer<A: Accept + 'static> {
|
|
mapping: Watch<Mapping<A>>,
|
|
bind_reqs: Watch<VHostBindRequirements>,
|
|
_thread: NonDetachingJoinHandle<()>,
|
|
}
|
|
|
|
impl<A: Accept> VHostServer<A> {
|
|
#[instrument(skip_all)]
|
|
fn new<M: HasModel>(
|
|
listener: A,
|
|
bind_reqs: Watch<VHostBindRequirements>,
|
|
db: TypedPatchDb<M>,
|
|
crypto_provider: Arc<CryptoProvider>,
|
|
acme_cache: AcmeTlsAlpnCache,
|
|
) -> Self
|
|
where
|
|
for<'a> M: HasModel<Model = Model<M>>
|
|
+ DbAccessMut<CertStore>
|
|
+ DbAccessMut<AcmeCertStore>
|
|
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
|
|
+ Send
|
|
+ Sync
|
|
+ 'static,
|
|
A: Accept + Send + 'static,
|
|
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
|
|
+ Visit<ExtractVisitor<GatewayInfo>>
|
|
+ Clone
|
|
+ Send
|
|
+ Sync
|
|
+ 'static,
|
|
{
|
|
let mapping = Watch::new(BTreeMap::new());
|
|
Self {
|
|
mapping: mapping.clone(),
|
|
bind_reqs,
|
|
_thread: tokio::spawn(async move {
|
|
let mut listener = VHostListener(TlsListener::new(
|
|
listener,
|
|
TlsHandlerWrapper {
|
|
inner: ChainedHandler(
|
|
Arc::new(AcmeTlsHandler {
|
|
db: db.clone(),
|
|
acme_cache,
|
|
crypto_provider: crypto_provider.clone(),
|
|
get_provider: GetVHostAcmeProvider(mapping.clone()),
|
|
in_progress: Watch::new(BTreeSet::new()),
|
|
}),
|
|
RootCaTlsHandler {
|
|
db,
|
|
crypto_provider,
|
|
},
|
|
),
|
|
wrapper: VHostConnector(mapping, None),
|
|
},
|
|
));
|
|
loop {
|
|
if let Err(e) = listener.handle_next().await {
|
|
tracing::trace!("VHostServer: failed to accept connection: {e}");
|
|
tracing::trace!("{e:?}");
|
|
}
|
|
}
|
|
})
|
|
.into(),
|
|
}
|
|
}
|
|
fn add(
|
|
&self,
|
|
hostname: Option<InternedString>,
|
|
target: DynVHostTarget<A>,
|
|
) -> Result<Arc<()>, Error> {
|
|
let target = target.into();
|
|
let mut res = Ok(Arc::new(()));
|
|
self.mapping.send_if_modified(|writable| {
|
|
let mut changed = false;
|
|
let mut targets = writable.remove(&hostname).unwrap_or_default();
|
|
let rc = if let Some(rc) = Weak::upgrade(&targets.remove(&target).unwrap_or_default()) {
|
|
rc
|
|
} else {
|
|
changed = true;
|
|
Arc::new(())
|
|
};
|
|
targets.retain(|_, rc| rc.strong_count() > 0);
|
|
targets.insert(target, Arc::downgrade(&rc));
|
|
writable.insert(hostname, targets);
|
|
res = Ok(rc);
|
|
if changed {
|
|
self.update_bind_reqs(writable);
|
|
}
|
|
changed
|
|
});
|
|
if self.mapping.watcher_count() > 1 {
|
|
res
|
|
} else {
|
|
Err(Error::new(
|
|
eyre!("VHost Service Thread has exited"),
|
|
crate::ErrorKind::Network,
|
|
))
|
|
}
|
|
}
|
|
fn gc(&self, hostname: Option<InternedString>) {
|
|
self.mapping.send_if_modified(|writable| {
|
|
let mut targets = writable.remove(&hostname).unwrap_or_default();
|
|
let pre = targets.len();
|
|
targets = targets
|
|
.into_iter()
|
|
.filter(|(_, rc)| rc.strong_count() > 0)
|
|
.collect();
|
|
let post = targets.len();
|
|
if !targets.is_empty() {
|
|
writable.insert(hostname, targets);
|
|
}
|
|
if pre != post {
|
|
self.update_bind_reqs(writable);
|
|
}
|
|
pre == post
|
|
});
|
|
}
|
|
fn update_bind_reqs(&self, mapping: &Mapping<A>) {
|
|
let new_reqs = compute_bind_reqs(mapping);
|
|
self.bind_reqs.send_if_modified(|reqs| {
|
|
if *reqs != new_reqs {
|
|
*reqs = new_reqs;
|
|
true
|
|
} else {
|
|
false
|
|
}
|
|
});
|
|
}
|
|
fn is_empty(&self) -> bool {
|
|
self.mapping.peek(|m| m.is_empty())
|
|
}
|
|
}
|