passthrough feature

This commit is contained in:
Aiden McClelland
2026-03-04 16:32:21 -07:00
parent 2ed8402edd
commit 0f8a66b357
8 changed files with 460 additions and 96 deletions

View File

@@ -24,7 +24,7 @@ use crate::net::host::Host;
use crate::net::host::binding::{ use crate::net::host::binding::{
AddSslOptions, BindInfo, BindOptions, Bindings, DerivedAddressInfo, NetInfo, AddSslOptions, BindInfo, BindOptions, Bindings, DerivedAddressInfo, NetInfo,
}; };
use crate::net::vhost::AlpnInfo; use crate::net::vhost::{AlpnInfo, PassthroughInfo};
use crate::prelude::*; use crate::prelude::*;
use crate::progress::FullProgress; use crate::progress::FullProgress;
use crate::system::{KeyboardOptions, SmtpValue}; use crate::system::{KeyboardOptions, SmtpValue};
@@ -121,6 +121,7 @@ impl Public {
}, },
dns: Default::default(), dns: Default::default(),
default_outbound: None, default_outbound: None,
passthroughs: Vec::new(),
}, },
status_info: ServerStatus { status_info: ServerStatus {
backup_progress: None, backup_progress: None,
@@ -233,6 +234,8 @@ pub struct NetworkInfo {
#[serde(default)] #[serde(default)]
#[ts(type = "string | null")] #[ts(type = "string | null")]
pub default_outbound: Option<GatewayId>, pub default_outbound: Option<GatewayId>,
#[serde(default)]
pub passthroughs: Vec<PassthroughInfo>,
} }
#[derive(Debug, Default, Deserialize, Serialize, HasModel, TS)] #[derive(Debug, Default, Deserialize, Serialize, HasModel, TS)]

View File

@@ -27,7 +27,7 @@ use crate::db::model::public::AcmeSettings;
use crate::db::{DbAccess, DbAccessByKey, DbAccessMut}; use crate::db::{DbAccess, DbAccessByKey, DbAccessMut};
use crate::error::ErrorData; use crate::error::ErrorData;
use crate::net::ssl::should_use_cert; use crate::net::ssl::should_use_cert;
use crate::net::tls::{SingleCertResolver, TlsHandler}; use crate::net::tls::{SingleCertResolver, TlsHandler, TlsHandlerAction};
use crate::net::web_server::Accept; use crate::net::web_server::Accept;
use crate::prelude::*; use crate::prelude::*;
use crate::util::FromStrParser; use crate::util::FromStrParser;
@@ -173,7 +173,7 @@ where
&'a mut self, &'a mut self,
hello: &'a ClientHello<'a>, hello: &'a ClientHello<'a>,
_: &'a <A as Accept>::Metadata, _: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig> { ) -> Option<TlsHandlerAction> {
let domain = hello.server_name()?; let domain = hello.server_name()?;
if hello if hello
.alpn() .alpn()
@@ -207,20 +207,20 @@ where
cfg.alpn_protocols = vec![ACME_TLS_ALPN_NAME.to_vec()]; cfg.alpn_protocols = vec![ACME_TLS_ALPN_NAME.to_vec()];
tracing::info!("performing ACME auth challenge"); tracing::info!("performing ACME auth challenge");
return Some(cfg); return Some(TlsHandlerAction::Tls(cfg));
} }
let domains: BTreeSet<InternedString> = [domain.into()].into_iter().collect(); let domains: BTreeSet<InternedString> = [domain.into()].into_iter().collect();
let crypto_provider = self.crypto_provider.clone(); let crypto_provider = self.crypto_provider.clone();
if let Some(cert) = self.get_cert(&domains).await { if let Some(cert) = self.get_cert(&domains).await {
return Some( return Some(TlsHandlerAction::Tls(
ServerConfig::builder_with_provider(crypto_provider) ServerConfig::builder_with_provider(crypto_provider)
.with_safe_default_protocol_versions() .with_safe_default_protocol_versions()
.log_err()? .log_err()?
.with_no_client_auth() .with_no_client_auth()
.with_cert_resolver(Arc::new(SingleCertResolver(Arc::new(cert)))), .with_cert_resolver(Arc::new(SingleCertResolver(Arc::new(cert)))),
); ));
} }
None None

View File

@@ -249,6 +249,20 @@ impl Model<Host> {
port: Some(port), port: Some(port),
metadata, metadata,
}); });
} else if opt.secure.map_or(false, |s| s.ssl)
&& opt.add_ssl.is_none()
&& available_ports.is_ssl(opt.preferred_external_port)
&& net.assigned_port != Some(opt.preferred_external_port)
{
// Service handles its own TLS and the preferred port is
// allocated as SSL — add an address for passthrough vhost.
available.insert(HostnameInfo {
ssl: true,
public: true,
hostname: domain,
port: Some(opt.preferred_external_port),
metadata,
});
} }
} }
@@ -293,6 +307,20 @@ impl Model<Host> {
gateways: domain_gateways, gateways: domain_gateways,
}, },
}); });
} else if opt.secure.map_or(false, |s| s.ssl)
&& opt.add_ssl.is_none()
&& available_ports.is_ssl(opt.preferred_external_port)
&& net.assigned_port != Some(opt.preferred_external_port)
{
available.insert(HostnameInfo {
ssl: true,
public: true,
hostname: domain,
port: Some(opt.preferred_external_port),
metadata: HostnameMetadata::PrivateDomain {
gateways: domain_gateways,
},
});
} }
} }
bind.as_addresses_mut().as_available_mut().ser(&available)?; bind.as_addresses_mut().as_available_mut().ser(&available)?;

View File

@@ -76,9 +76,22 @@ impl NetController {
], ],
) )
.await?; .await?;
let passthroughs = db
.peek()
.await
.as_public()
.as_server_info()
.as_network()
.as_passthroughs()
.de()?;
Ok(Self { Ok(Self {
db: db.clone(), db: db.clone(),
vhost: VHostController::new(db.clone(), net_iface.clone(), crypto_provider), vhost: VHostController::new(
db.clone(),
net_iface.clone(),
crypto_provider,
passthroughs,
),
tls_client_config, tls_client_config,
dns: DnsController::init(db, &net_iface.watcher).await?, dns: DnsController::init(db, &net_iface.watcher).await?,
forward: InterfacePortForwardController::new(net_iface.watcher.subscribe()), forward: InterfacePortForwardController::new(net_iface.watcher.subscribe()),
@@ -237,6 +250,7 @@ impl NetServiceData {
connect_ssl: connect_ssl connect_ssl: connect_ssl
.clone() .clone()
.map(|_| ctrl.tls_client_config.clone()), .map(|_| ctrl.tls_client_config.clone()),
passthrough: false,
}, },
); );
} }
@@ -253,7 +267,9 @@ impl NetServiceData {
_ => continue, _ => continue,
} }
let domain = &addr_info.hostname; let domain = &addr_info.hostname;
let domain_ssl_port = addr_info.port.unwrap_or(443); let Some(domain_ssl_port) = addr_info.port else {
continue;
};
let key = (Some(domain.clone()), domain_ssl_port); let key = (Some(domain.clone()), domain_ssl_port);
let target = vhosts.entry(key).or_insert_with(|| ProxyTarget { let target = vhosts.entry(key).or_insert_with(|| ProxyTarget {
public: BTreeSet::new(), public: BTreeSet::new(),
@@ -266,6 +282,7 @@ impl NetServiceData {
addr, addr,
add_x_forwarded_headers: ssl.add_x_forwarded_headers, add_x_forwarded_headers: ssl.add_x_forwarded_headers,
connect_ssl: connect_ssl.clone().map(|_| ctrl.tls_client_config.clone()), connect_ssl: connect_ssl.clone().map(|_| ctrl.tls_client_config.clone()),
passthrough: false,
}); });
if addr_info.public { if addr_info.public {
for gw in addr_info.metadata.gateways() { for gw in addr_info.metadata.gateways() {
@@ -317,6 +334,53 @@ impl NetServiceData {
), ),
); );
} }
// Passthrough vhosts: if the service handles its own TLS
// (secure.ssl && no add_ssl) and a domain address is enabled on
// an SSL port different from assigned_port, add a passthrough
// vhost so the service's TLS endpoint is reachable on that port.
if bind.options.secure.map_or(false, |s| s.ssl) && bind.options.add_ssl.is_none() {
let assigned = bind.net.assigned_port;
for addr_info in &enabled_addresses {
if !addr_info.ssl {
continue;
}
let Some(pt_port) = addr_info.port.filter(|p| assigned != Some(*p)) else {
continue;
};
match &addr_info.metadata {
HostnameMetadata::PublicDomain { .. }
| HostnameMetadata::PrivateDomain { .. } => {}
_ => continue,
}
let domain = &addr_info.hostname;
let key = (Some(domain.clone()), pt_port);
let target = vhosts.entry(key).or_insert_with(|| ProxyTarget {
public: BTreeSet::new(),
private: BTreeSet::new(),
acme: None,
addr,
add_x_forwarded_headers: false,
connect_ssl: Err(AlpnInfo::Reflect),
passthrough: true,
});
if addr_info.public {
for gw in addr_info.metadata.gateways() {
target.public.insert(gw.clone());
}
} else {
for gw in addr_info.metadata.gateways() {
if let Some(info) = net_ifaces.get(gw) {
if let Some(ip_info) = &info.ip_info {
for subnet in &ip_info.subnets {
target.private.insert(subnet.addr());
}
}
}
}
}
}
}
} }
// ── Phase 3: Reconcile ── // ── Phase 3: Reconcile ──

View File

@@ -36,7 +36,7 @@ use crate::db::{DbAccess, DbAccessMut};
use crate::hostname::ServerHostname; use crate::hostname::ServerHostname;
use crate::init::check_time_is_synchronized; use crate::init::check_time_is_synchronized;
use crate::net::gateway::GatewayInfo; use crate::net::gateway::GatewayInfo;
use crate::net::tls::TlsHandler; use crate::net::tls::{TlsHandler, TlsHandlerAction};
use crate::net::web_server::{Accept, ExtractVisitor, TcpMetadata, extract}; use crate::net::web_server::{Accept, ExtractVisitor, TcpMetadata, extract};
use crate::prelude::*; use crate::prelude::*;
use crate::util::serde::Pem; use crate::util::serde::Pem;
@@ -620,7 +620,7 @@ where
&mut self, &mut self,
hello: &ClientHello<'_>, hello: &ClientHello<'_>,
metadata: &<A as Accept>::Metadata, metadata: &<A as Accept>::Metadata,
) -> Option<ServerConfig> { ) -> Option<TlsHandlerAction> {
let hostnames: BTreeSet<InternedString> = hello let hostnames: BTreeSet<InternedString> = hello
.server_name() .server_name()
.map(InternedString::from) .map(InternedString::from)
@@ -684,5 +684,6 @@ where
) )
} }
.log_err() .log_err()
.map(TlsHandlerAction::Tls)
} }
} }

View File

@@ -16,6 +16,14 @@ use tokio_rustls::rustls::sign::CertifiedKey;
use tokio_rustls::rustls::{ClientConfig, RootCertStore, ServerConfig}; use tokio_rustls::rustls::{ClientConfig, RootCertStore, ServerConfig};
use visit_rs::{Visit, VisitFields}; use visit_rs::{Visit, VisitFields};
/// Result of a TLS handler's decision about how to handle a connection.
pub enum TlsHandlerAction {
/// Complete the TLS handshake with this ServerConfig.
Tls(ServerConfig),
/// Don't complete TLS — rewind the BackTrackingIO and return the raw stream.
Passthrough,
}
use crate::net::http::handle_http_on_https; use crate::net::http::handle_http_on_https;
use crate::net::web_server::{Accept, AcceptStream, MetadataVisitor}; use crate::net::web_server::{Accept, AcceptStream, MetadataVisitor};
use crate::prelude::*; use crate::prelude::*;
@@ -50,7 +58,7 @@ pub trait TlsHandler<'a, A: Accept> {
&'a mut self, &'a mut self,
hello: &'a ClientHello<'a>, hello: &'a ClientHello<'a>,
metadata: &'a A::Metadata, metadata: &'a A::Metadata,
) -> impl Future<Output = Option<ServerConfig>> + Send + 'a; ) -> impl Future<Output = Option<TlsHandlerAction>> + Send + 'a;
} }
#[derive(Clone)] #[derive(Clone)]
@@ -66,7 +74,7 @@ where
&'a mut self, &'a mut self,
hello: &'a ClientHello<'a>, hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata, metadata: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig> { ) -> Option<TlsHandlerAction> {
if let Some(config) = self.0.get_config(hello, metadata).await { if let Some(config) = self.0.get_config(hello, metadata).await {
return Some(config); return Some(config);
} }
@@ -86,7 +94,7 @@ pub trait WrapTlsHandler<A: Accept> {
prev: ServerConfig, prev: ServerConfig,
hello: &'a ClientHello<'a>, hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata, metadata: &'a <A as Accept>::Metadata,
) -> impl Future<Output = Option<ServerConfig>> + Send + 'a ) -> impl Future<Output = Option<TlsHandlerAction>> + Send + 'a
where where
Self: 'a; Self: 'a;
} }
@@ -102,9 +110,12 @@ where
&'a mut self, &'a mut self,
hello: &'a ClientHello<'a>, hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata, metadata: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig> { ) -> Option<TlsHandlerAction> {
let prev = self.inner.get_config(hello, metadata).await?; let action = self.inner.get_config(hello, metadata).await?;
self.wrapper.wrap(prev, hello, metadata).await match action {
TlsHandlerAction::Tls(cfg) => self.wrapper.wrap(cfg, hello, metadata).await,
other => Some(other),
}
} }
} }
@@ -203,34 +214,56 @@ where
} }
}; };
let hello = mid.client_hello(); let hello = mid.client_hello();
if let Some(cfg) = tls_handler.get_config(&hello, &metadata).await { let sni = hello.server_name().map(InternedString::intern);
let buffered = mid.io.stop_buffering(); match tls_handler.get_config(&hello, &metadata).await {
mid.io Some(TlsHandlerAction::Tls(cfg)) => {
.write_all(&buffered) let buffered = mid.io.stop_buffering();
.await mid.io
.with_kind(ErrorKind::Network)?; .write_all(&buffered)
return Ok(match mid.into_stream(Arc::new(cfg)).await { .await
Ok(stream) => { .with_kind(ErrorKind::Network)?;
let s = stream.get_ref().1; return Ok(match mid.into_stream(Arc::new(cfg)).await {
Some(( Ok(stream) => {
TlsMetadata { let s = stream.get_ref().1;
inner: metadata, Some((
tls_info: TlsHandshakeInfo { TlsMetadata {
sni: s.server_name().map(InternedString::intern), inner: metadata,
alpn: s tls_info: TlsHandshakeInfo {
.alpn_protocol() sni: s
.map(|a| MaybeUtf8String(a.to_vec())), .server_name()
.map(InternedString::intern),
alpn: s
.alpn_protocol()
.map(|a| MaybeUtf8String(a.to_vec())),
},
}, },
}, Box::pin(stream) as AcceptStream,
Box::pin(stream) as AcceptStream, ))
)) }
} Err(e) => {
Err(e) => { tracing::trace!("Error completing TLS handshake: {e}");
tracing::trace!("Error completing TLS handshake: {e}"); tracing::trace!("{e:?}");
tracing::trace!("{e:?}"); None
None }
} });
}); }
Some(TlsHandlerAction::Passthrough) => {
let (dummy, _drop) = tokio::io::duplex(1);
let mut bt = std::mem::replace(
&mut mid.io,
BackTrackingIO::new(Box::pin(dummy) as AcceptStream),
);
drop(mid);
bt.rewind();
return Ok(Some((
TlsMetadata {
inner: metadata,
tls_info: TlsHandshakeInfo { sni, alpn: None },
},
Box::pin(bt) as AcceptStream,
)));
}
None => {}
} }
Ok(None) Ok(None)

View File

@@ -6,12 +6,13 @@ use std::sync::{Arc, Weak};
use std::task::{Poll, ready}; use std::task::{Poll, ready};
use async_acme::acme::ACME_TLS_ALPN_NAME; use async_acme::acme::ACME_TLS_ALPN_NAME;
use clap::Parser;
use color_eyre::eyre::eyre; use color_eyre::eyre::eyre;
use futures::FutureExt; use futures::FutureExt;
use futures::future::BoxFuture; use futures::future::BoxFuture;
use imbl::OrdMap; use imbl::OrdMap;
use imbl_value::{InOMap, InternedString}; use imbl_value::{InOMap, InternedString};
use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn}; use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn, from_fn_async};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::TlsConnector; use tokio_rustls::TlsConnector;
@@ -35,7 +36,7 @@ use crate::net::gateway::{
}; };
use crate::net::ssl::{CertStore, RootCaTlsHandler}; use crate::net::ssl::{CertStore, RootCaTlsHandler};
use crate::net::tls::{ use crate::net::tls::{
ChainedHandler, TlsHandlerWrapper, TlsListener, TlsMetadata, WrapTlsHandler, ChainedHandler, TlsHandlerAction, TlsHandlerWrapper, TlsListener, TlsMetadata, WrapTlsHandler,
}; };
use crate::net::utils::ipv6_is_link_local; use crate::net::utils::ipv6_is_link_local;
use crate::net::web_server::{Accept, AcceptStream, ExtractVisitor, TcpMetadata, extract}; use crate::net::web_server::{Accept, AcceptStream, ExtractVisitor, TcpMetadata, extract};
@@ -46,68 +47,228 @@ use crate::util::serde::{HandlerExtSerde, MaybeUtf8String, display_serializable}
use crate::util::sync::{SyncMutex, Watch}; use crate::util::sync::{SyncMutex, Watch};
use crate::{GatewayId, ResultExt}; use crate::{GatewayId, ResultExt};
#[derive(Debug, Clone, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
#[ts(export)]
pub struct PassthroughInfo {
#[ts(type = "string")]
pub hostname: InternedString,
pub listen_port: u16,
#[ts(type = "string")]
pub backend: SocketAddr,
#[ts(type = "string[]")]
pub public_gateways: BTreeSet<GatewayId>,
#[ts(type = "string[]")]
pub private_ips: BTreeSet<IpAddr>,
}
#[derive(Debug, Clone, Deserialize, Serialize, Parser)]
#[serde(rename_all = "kebab-case")]
struct AddPassthroughParams {
#[arg(long)]
pub hostname: InternedString,
#[arg(long)]
pub listen_port: u16,
#[arg(long)]
pub backend: SocketAddr,
#[arg(long)]
pub public_gateway: Vec<GatewayId>,
#[arg(long)]
pub private_ip: Vec<IpAddr>,
}
#[derive(Debug, Clone, Deserialize, Serialize, Parser)]
#[serde(rename_all = "kebab-case")]
struct RemovePassthroughParams {
#[arg(long)]
pub hostname: InternedString,
#[arg(long)]
pub listen_port: u16,
}
pub fn vhost_api<C: Context>() -> ParentHandler<C> { pub fn vhost_api<C: Context>() -> ParentHandler<C> {
ParentHandler::new().subcommand( ParentHandler::new()
"dump-table", .subcommand(
from_fn(|ctx: RpcContext| Ok(ctx.net_controller.vhost.dump_table())) "dump-table",
.with_display_serializable() from_fn(dump_table)
.with_custom_display_fn(|HandlerArgs { params, .. }, res| { .with_display_serializable()
use prettytable::*; .with_custom_display_fn(|HandlerArgs { params, .. }, res| {
use prettytable::*;
if let Some(format) = params.format { if let Some(format) = params.format {
display_serializable(format, res)?; display_serializable(format, res)?;
return Ok::<_, Error>(()); return Ok::<_, Error>(());
} }
let mut table = Table::new(); let mut table = Table::new();
table.add_row(row![bc => "FROM", "TO", "ACTIVE"]); table.add_row(row![bc => "FROM", "TO", "ACTIVE"]);
for (external, targets) in res { for (external, targets) in res {
for (host, targets) in targets { for (host, targets) in targets {
for (idx, target) in targets.into_iter().enumerate() { for (idx, target) in targets.into_iter().enumerate() {
table.add_row(row![ table.add_row(row![
format!( format!(
"{}:{}", "{}:{}",
host.as_ref().map(|s| &**s).unwrap_or("*"), host.as_ref().map(|s| &**s).unwrap_or("*"),
external.0 external.0
), ),
target, target,
idx == 0 idx == 0
]); ]);
}
} }
} }
}
table.print_tty(false)?; table.print_tty(false)?;
Ok(()) Ok(())
}) })
.with_call_remote::<CliContext>(), .with_call_remote::<CliContext>(),
) )
.subcommand(
"add-passthrough",
from_fn_async(add_passthrough)
.no_display()
.with_call_remote::<CliContext>(),
)
.subcommand(
"remove-passthrough",
from_fn_async(remove_passthrough)
.no_display()
.with_call_remote::<CliContext>(),
)
.subcommand(
"list-passthrough",
from_fn(list_passthrough)
.with_display_serializable()
.with_call_remote::<CliContext>(),
)
}
fn dump_table(
ctx: RpcContext,
) -> Result<BTreeMap<JsonKey<u16>, BTreeMap<JsonKey<Option<InternedString>>, EqSet<String>>>, Error>
{
Ok(ctx.net_controller.vhost.dump_table())
}
async fn add_passthrough(
ctx: RpcContext,
AddPassthroughParams {
hostname,
listen_port,
backend,
public_gateway,
private_ip,
}: AddPassthroughParams,
) -> Result<(), Error> {
let public_gateways: BTreeSet<GatewayId> = public_gateway.into_iter().collect();
let private_ips: BTreeSet<IpAddr> = private_ip.into_iter().collect();
ctx.net_controller.vhost.add_passthrough(
hostname.clone(),
listen_port,
backend,
public_gateways.clone(),
private_ips.clone(),
)?;
ctx.db
.mutate(|db| {
let pts = db
.as_public_mut()
.as_server_info_mut()
.as_network_mut()
.as_passthroughs_mut();
let mut vec: Vec<PassthroughInfo> = pts.de()?;
vec.retain(|p| !(p.hostname == hostname && p.listen_port == listen_port));
vec.push(PassthroughInfo {
hostname,
listen_port,
backend,
public_gateways,
private_ips,
});
pts.ser(&vec)
})
.await
.result?;
Ok(())
}
async fn remove_passthrough(
ctx: RpcContext,
RemovePassthroughParams {
hostname,
listen_port,
}: RemovePassthroughParams,
) -> Result<(), Error> {
ctx.net_controller
.vhost
.remove_passthrough(&hostname, listen_port);
ctx.db
.mutate(|db| {
let pts = db
.as_public_mut()
.as_server_info_mut()
.as_network_mut()
.as_passthroughs_mut();
let mut vec: Vec<PassthroughInfo> = pts.de()?;
vec.retain(|p| !(p.hostname == hostname && p.listen_port == listen_port));
pts.ser(&vec)
})
.await
.result?;
Ok(())
}
fn list_passthrough(ctx: RpcContext) -> Result<Vec<PassthroughInfo>, Error> {
Ok(ctx.net_controller.vhost.list_passthrough())
} }
// not allowed: <=1024, >=32768, 5355, 5432, 9050, 6010, 9051, 5353 // not allowed: <=1024, >=32768, 5355, 5432, 9050, 6010, 9051, 5353
struct PassthroughHandle {
_rc: Arc<()>,
backend: SocketAddr,
public: BTreeSet<GatewayId>,
private: BTreeSet<IpAddr>,
}
pub struct VHostController { pub struct VHostController {
db: TypedPatchDb<Database>, db: TypedPatchDb<Database>,
interfaces: Arc<NetworkInterfaceController>, interfaces: Arc<NetworkInterfaceController>,
crypto_provider: Arc<CryptoProvider>, crypto_provider: Arc<CryptoProvider>,
acme_cache: AcmeTlsAlpnCache, acme_cache: AcmeTlsAlpnCache,
servers: SyncMutex<BTreeMap<u16, VHostServer<VHostBindListener>>>, servers: SyncMutex<BTreeMap<u16, VHostServer<VHostBindListener>>>,
passthrough_handles: SyncMutex<BTreeMap<(InternedString, u16), PassthroughHandle>>,
} }
impl VHostController { impl VHostController {
pub fn new( pub fn new(
db: TypedPatchDb<Database>, db: TypedPatchDb<Database>,
interfaces: Arc<NetworkInterfaceController>, interfaces: Arc<NetworkInterfaceController>,
crypto_provider: Arc<CryptoProvider>, crypto_provider: Arc<CryptoProvider>,
passthroughs: Vec<PassthroughInfo>,
) -> Self { ) -> Self {
Self { let controller = Self {
db, db,
interfaces, interfaces,
crypto_provider, crypto_provider,
acme_cache: Arc::new(SyncMutex::new(BTreeMap::new())), acme_cache: Arc::new(SyncMutex::new(BTreeMap::new())),
servers: SyncMutex::new(BTreeMap::new()), servers: SyncMutex::new(BTreeMap::new()),
passthrough_handles: SyncMutex::new(BTreeMap::new()),
};
for pt in passthroughs {
if let Err(e) = controller.add_passthrough(
pt.hostname,
pt.listen_port,
pt.backend,
pt.public_gateways,
pt.private_ips,
) {
tracing::warn!("failed to restore passthrough: {e}");
}
} }
controller
} }
#[instrument(skip_all)] #[instrument(skip_all)]
pub fn add( pub fn add(
@@ -120,20 +281,7 @@ impl VHostController {
let server = if let Some(server) = writable.remove(&external) { let server = if let Some(server) = writable.remove(&external) {
server server
} else { } else {
let bind_reqs = Watch::new(VHostBindRequirements::default()); self.create_server(external)
let listener = VHostBindListener {
ip_info: self.interfaces.watcher.subscribe(),
port: external,
bind_reqs: bind_reqs.clone_unseen(),
listeners: BTreeMap::new(),
};
VHostServer::new(
listener,
bind_reqs,
self.db.clone(),
self.crypto_provider.clone(),
self.acme_cache.clone(),
)
}; };
let rc = server.add(hostname, target); let rc = server.add(hostname, target);
writable.insert(external, server); writable.insert(external, server);
@@ -141,6 +289,75 @@ impl VHostController {
}) })
} }
fn create_server(&self, port: u16) -> VHostServer<VHostBindListener> {
let bind_reqs = Watch::new(VHostBindRequirements::default());
let listener = VHostBindListener {
ip_info: self.interfaces.watcher.subscribe(),
port,
bind_reqs: bind_reqs.clone_unseen(),
listeners: BTreeMap::new(),
};
VHostServer::new(
listener,
bind_reqs,
self.db.clone(),
self.crypto_provider.clone(),
self.acme_cache.clone(),
)
}
pub fn add_passthrough(
&self,
hostname: InternedString,
port: u16,
backend: SocketAddr,
public: BTreeSet<GatewayId>,
private: BTreeSet<IpAddr>,
) -> Result<(), Error> {
let target = ProxyTarget {
public: public.clone(),
private: private.clone(),
acme: None,
addr: backend,
add_x_forwarded_headers: false,
connect_ssl: Err(AlpnInfo::Reflect),
passthrough: true,
};
let rc = self.add(Some(hostname.clone()), port, DynVHostTarget::new(target))?;
self.passthrough_handles.mutate(|h| {
h.insert(
(hostname, port),
PassthroughHandle {
_rc: rc,
backend,
public,
private,
},
);
});
Ok(())
}
pub fn remove_passthrough(&self, hostname: &InternedString, port: u16) {
self.passthrough_handles
.mutate(|h| h.remove(&(hostname.clone(), port)));
self.gc(Some(hostname.clone()), port);
}
pub fn list_passthrough(&self) -> Vec<PassthroughInfo> {
self.passthrough_handles.peek(|h| {
h.iter()
.map(|((hostname, port), handle)| PassthroughInfo {
hostname: hostname.clone(),
listen_port: *port,
backend: handle.backend,
public_gateways: handle.public.clone(),
private_ips: handle.private.clone(),
})
.collect()
})
}
pub fn dump_table( pub fn dump_table(
&self, &self,
) -> BTreeMap<JsonKey<u16>, BTreeMap<JsonKey<Option<InternedString>>, EqSet<String>>> { ) -> BTreeMap<JsonKey<u16>, BTreeMap<JsonKey<Option<InternedString>>, EqSet<String>>> {
@@ -330,6 +547,9 @@ pub trait VHostTarget<A: Accept>: std::fmt::Debug + Eq {
fn bind_requirements(&self) -> (BTreeSet<GatewayId>, BTreeSet<IpAddr>) { fn bind_requirements(&self) -> (BTreeSet<GatewayId>, BTreeSet<IpAddr>) {
(BTreeSet::new(), BTreeSet::new()) (BTreeSet::new(), BTreeSet::new())
} }
fn is_passthrough(&self) -> bool {
false
}
fn preprocess<'a>( fn preprocess<'a>(
&'a self, &'a self,
prev: ServerConfig, prev: ServerConfig,
@@ -349,6 +569,7 @@ pub trait DynVHostTargetT<A: Accept>: std::fmt::Debug + Any {
fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool; fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool;
fn acme(&self) -> Option<&AcmeProvider>; fn acme(&self) -> Option<&AcmeProvider>;
fn bind_requirements(&self) -> (BTreeSet<GatewayId>, BTreeSet<IpAddr>); fn bind_requirements(&self) -> (BTreeSet<GatewayId>, BTreeSet<IpAddr>);
fn is_passthrough(&self) -> bool;
fn preprocess<'a>( fn preprocess<'a>(
&'a self, &'a self,
prev: ServerConfig, prev: ServerConfig,
@@ -373,6 +594,9 @@ impl<A: Accept, T: VHostTarget<A> + 'static> DynVHostTargetT<A> for T {
fn acme(&self) -> Option<&AcmeProvider> { fn acme(&self) -> Option<&AcmeProvider> {
VHostTarget::acme(self) VHostTarget::acme(self)
} }
fn is_passthrough(&self) -> bool {
VHostTarget::is_passthrough(self)
}
fn bind_requirements(&self) -> (BTreeSet<GatewayId>, BTreeSet<IpAddr>) { fn bind_requirements(&self) -> (BTreeSet<GatewayId>, BTreeSet<IpAddr>) {
VHostTarget::bind_requirements(self) VHostTarget::bind_requirements(self)
} }
@@ -459,6 +683,7 @@ pub struct ProxyTarget {
pub addr: SocketAddr, pub addr: SocketAddr,
pub add_x_forwarded_headers: bool, pub add_x_forwarded_headers: bool,
pub connect_ssl: Result<Arc<ClientConfig>, AlpnInfo>, // Ok: yes, connect using ssl, pass through alpn; Err: connect tcp, use provided strategy for alpn pub connect_ssl: Result<Arc<ClientConfig>, AlpnInfo>, // Ok: yes, connect using ssl, pass through alpn; Err: connect tcp, use provided strategy for alpn
pub passthrough: bool,
} }
impl PartialEq for ProxyTarget { impl PartialEq for ProxyTarget {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
@@ -466,6 +691,7 @@ impl PartialEq for ProxyTarget {
&& self.private == other.private && self.private == other.private
&& self.acme == other.acme && self.acme == other.acme
&& self.addr == other.addr && self.addr == other.addr
&& self.passthrough == other.passthrough
&& self.connect_ssl.as_ref().map(Arc::as_ptr) && self.connect_ssl.as_ref().map(Arc::as_ptr)
== other.connect_ssl.as_ref().map(Arc::as_ptr) == other.connect_ssl.as_ref().map(Arc::as_ptr)
} }
@@ -480,6 +706,7 @@ impl fmt::Debug for ProxyTarget {
.field("addr", &self.addr) .field("addr", &self.addr)
.field("add_x_forwarded_headers", &self.add_x_forwarded_headers) .field("add_x_forwarded_headers", &self.add_x_forwarded_headers)
.field("connect_ssl", &self.connect_ssl.as_ref().map(|_| ())) .field("connect_ssl", &self.connect_ssl.as_ref().map(|_| ()))
.field("passthrough", &self.passthrough)
.finish() .finish()
} }
} }
@@ -524,6 +751,9 @@ where
fn bind_requirements(&self) -> (BTreeSet<GatewayId>, BTreeSet<IpAddr>) { fn bind_requirements(&self) -> (BTreeSet<GatewayId>, BTreeSet<IpAddr>) {
(self.public.clone(), self.private.clone()) (self.public.clone(), self.private.clone())
} }
fn is_passthrough(&self) -> bool {
self.passthrough
}
async fn preprocess<'a>( async fn preprocess<'a>(
&'a self, &'a self,
mut prev: ServerConfig, mut prev: ServerConfig,
@@ -677,7 +907,7 @@ where
prev: ServerConfig, prev: ServerConfig,
hello: &'a ClientHello<'a>, hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata, metadata: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig> ) -> Option<TlsHandlerAction>
where where
Self: 'a, Self: 'a,
{ {
@@ -687,7 +917,7 @@ where
.flatten() .flatten()
.any(|a| a == ACME_TLS_ALPN_NAME) .any(|a| a == ACME_TLS_ALPN_NAME)
{ {
return Some(prev); return Some(TlsHandlerAction::Tls(prev));
} }
let (target, rc) = self.0.peek(|m| { let (target, rc) = self.0.peek(|m| {
@@ -700,11 +930,16 @@ where
.map(|(t, rc)| (t.clone(), rc.clone())) .map(|(t, rc)| (t.clone(), rc.clone()))
})?; })?;
let is_pt = target.0.is_passthrough();
let (prev, store) = target.into_preprocessed(rc, prev, hello, metadata).await?; let (prev, store) = target.into_preprocessed(rc, prev, hello, metadata).await?;
self.1 = Some(store); self.1 = Some(store);
Some(prev) if is_pt {
Some(TlsHandlerAction::Passthrough)
} else {
Some(TlsHandlerAction::Tls(prev))
}
} }
} }

View File

@@ -20,7 +20,7 @@ use ts_rs::TS;
use crate::context::CliContext; use crate::context::CliContext;
use crate::hostname::ServerHostname; use crate::hostname::ServerHostname;
use crate::net::ssl::{SANInfo, root_ca_start_time}; use crate::net::ssl::{SANInfo, root_ca_start_time};
use crate::net::tls::TlsHandler; use crate::net::tls::{TlsHandler, TlsHandlerAction};
use crate::net::web_server::Accept; use crate::net::web_server::Accept;
use crate::prelude::*; use crate::prelude::*;
use crate::tunnel::auth::SetPasswordParams; use crate::tunnel::auth::SetPasswordParams;
@@ -59,7 +59,7 @@ where
&'a mut self, &'a mut self,
_: &'a ClientHello<'a>, _: &'a ClientHello<'a>,
_: &'a <A as Accept>::Metadata, _: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig> { ) -> Option<TlsHandlerAction> {
let cert_info = self let cert_info = self
.db .db
.peek() .peek()
@@ -88,7 +88,7 @@ where
.log_err()?; .log_err()?;
cfg.alpn_protocols cfg.alpn_protocols
.extend([b"http/1.1".into(), b"h2".into()]); .extend([b"http/1.1".into(), b"h2".into()]);
Some(cfg) Some(TlsHandlerAction::Tls(cfg))
} }
} }