diff --git a/build/lib/scripts/forward-port b/build/lib/scripts/forward-port index 705c1e6a7..31f42a39e 100755 --- a/build/lib/scripts/forward-port +++ b/build/lib/scripts/forward-port @@ -5,7 +5,7 @@ if [ -z "$sip" ] || [ -z "$dip" ] || [ -z "$dprefix" ] || [ -z "$sport" ] || [ - exit 1 fi -NAME="F$(echo "$sip:$sport -> $dip/$dprefix:$dport" | sha256sum | head -c 15)" +NAME="F$(echo "$sip:$sport -> $dip/$dprefix:$dport ${src_subnet:-any} ${excluded_src:-none}" | sha256sum | head -c 15)" for kind in INPUT FORWARD ACCEPT; do if ! iptables -C $kind -j "${NAME}_${kind}" 2> /dev/null; then @@ -36,8 +36,22 @@ if [ "$UNDO" = 1 ]; then fi # DNAT: rewrite destination for incoming packets (external traffic) -iptables -t nat -A ${NAME}_PREROUTING -d "$sip" -p tcp --dport "$sport" -j DNAT --to-destination "$dip:$dport" -iptables -t nat -A ${NAME}_PREROUTING -d "$sip" -p udp --dport "$sport" -j DNAT --to-destination "$dip:$dport" +# When src_subnet is set, only forward traffic from that subnet (private forwards) +# excluded_src: comma-separated gateway/router IPs to reject (they may masquerade internet traffic) +if [ -n "$src_subnet" ]; then + if [ -n "$excluded_src" ]; then + IFS=',' read -ra EXCLUDED <<< "$excluded_src" + for excl in "${EXCLUDED[@]}"; do + iptables -t nat -A ${NAME}_PREROUTING -s "$excl" -d "$sip" -p tcp --dport "$sport" -j RETURN + iptables -t nat -A ${NAME}_PREROUTING -s "$excl" -d "$sip" -p udp --dport "$sport" -j RETURN + done + fi + iptables -t nat -A ${NAME}_PREROUTING -s "$src_subnet" -d "$sip" -p tcp --dport "$sport" -j DNAT --to-destination "$dip:$dport" + iptables -t nat -A ${NAME}_PREROUTING -s "$src_subnet" -d "$sip" -p udp --dport "$sport" -j DNAT --to-destination "$dip:$dport" +else + iptables -t nat -A ${NAME}_PREROUTING -d "$sip" -p tcp --dport "$sport" -j DNAT --to-destination "$dip:$dport" + iptables -t nat -A ${NAME}_PREROUTING -d "$sip" -p udp --dport "$sport" -j DNAT --to-destination "$dip:$dport" +fi # DNAT: rewrite destination for locally-originated packets (hairpin from host itself) iptables -t nat -A ${NAME}_OUTPUT -d "$sip" -p tcp --dport "$sport" -j DNAT --to-destination "$dip:$dport" @@ -52,4 +66,4 @@ iptables -t nat -A ${NAME}_POSTROUTING -d "$dip" -p udp --dport "$dport" -j MASQ iptables -A ${NAME}_FORWARD -d $dip -p tcp --dport $dport -m state --state NEW -j ACCEPT iptables -A ${NAME}_FORWARD -d $dip -p udp --dport $dport -m state --state NEW -j ACCEPT -exit $err \ No newline at end of file +exit $err diff --git a/core/src/bins/start_init.rs b/core/src/bins/start_init.rs index 48e65f5af..5c53a6e0c 100644 --- a/core/src/bins/start_init.rs +++ b/core/src/bins/start_init.rs @@ -9,7 +9,7 @@ use crate::disk::fsck::RepairStrategy; use crate::disk::main::DEFAULT_PASSWORD; use crate::firmware::{check_for_firmware_update, update_firmware}; use crate::init::{InitPhases, STANDBY_MODE_PATH}; -use crate::net::gateway::UpgradableListener; +use crate::net::gateway::WildcardListener; use crate::net::web_server::WebServer; use crate::prelude::*; use crate::progress::FullProgressTracker; @@ -19,7 +19,7 @@ use crate::{DATA_DIR, PLATFORM}; #[instrument(skip_all)] async fn setup_or_init( - server: &mut WebServer, + server: &mut WebServer, config: &ServerConfig, ) -> Result, Error> { if let Some(firmware) = check_for_firmware_update() @@ -204,7 +204,7 @@ async fn setup_or_init( #[instrument(skip_all)] pub async fn main( - server: &mut WebServer, + server: &mut WebServer, config: &ServerConfig, ) -> Result, Error> { if &*PLATFORM == "raspberrypi" && tokio::fs::metadata(STANDBY_MODE_PATH).await.is_ok() { diff --git a/core/src/bins/startd.rs b/core/src/bins/startd.rs index f4a7784f4..b88f622e5 100644 --- a/core/src/bins/startd.rs +++ b/core/src/bins/startd.rs @@ -12,7 +12,7 @@ use tracing::instrument; use crate::context::config::ServerConfig; use crate::context::rpc::InitRpcContextPhases; use crate::context::{DiagnosticContext, InitContext, RpcContext}; -use crate::net::gateway::{BindTcp, SelfContainedNetworkInterfaceListener, UpgradableListener}; +use crate::net::gateway::WildcardListener; use crate::net::static_server::refresher; use crate::net::web_server::{Acceptor, WebServer}; use crate::prelude::*; @@ -23,7 +23,7 @@ use crate::util::logger::LOGGER; #[instrument(skip_all)] async fn inner_main( - server: &mut WebServer, + server: &mut WebServer, config: &ServerConfig, ) -> Result, Error> { let rpc_ctx = if !tokio::fs::metadata("/run/startos/initialized") @@ -148,7 +148,7 @@ pub fn main(args: impl IntoIterator) { .expect(&t!("bins.startd.failed-to-initialize-runtime")); let res = rt.block_on(async { let mut server = WebServer::new( - Acceptor::bind_upgradable(SelfContainedNetworkInterfaceListener::bind(BindTcp, 80)), + Acceptor::new(WildcardListener::new(80)?), refresher(), ); match inner_main(&mut server, &config).await { diff --git a/core/src/bins/tunnel.rs b/core/src/bins/tunnel.rs index 97fb818ea..57615c00c 100644 --- a/core/src/bins/tunnel.rs +++ b/core/src/bins/tunnel.rs @@ -13,7 +13,7 @@ use visit_rs::Visit; use crate::context::CliContext; use crate::context::config::ClientConfig; -use crate::net::gateway::{Bind, BindTcp}; +use tokio::net::TcpListener; use crate::net::tls::TlsListener; use crate::net::web_server::{Accept, Acceptor, MetadataVisitor, WebServer}; use crate::prelude::*; @@ -57,7 +57,12 @@ async fn inner_main(config: &TunnelConfig) -> Result<(), Error> { if !a.contains_key(&key) { match (|| { Ok::<_, Error>(TlsListener::new( - BindTcp.bind(addr)?, + TcpListener::from_std( + mio::net::TcpListener::bind(addr) + .with_kind(ErrorKind::Network)? + .into(), + ) + .with_kind(ErrorKind::Network)?, TunnelCertHandler { db: https_db.clone(), crypto_provider: Arc::new(tokio_rustls::rustls::crypto::ring::default_provider()), diff --git a/core/src/context/rpc.rs b/core/src/context/rpc.rs index a59d60236..6350672f1 100644 --- a/core/src/context/rpc.rs +++ b/core/src/context/rpc.rs @@ -34,7 +34,7 @@ use crate::disk::mount::guard::MountGuard; use crate::init::{InitResult, check_time_is_synchronized}; use crate::install::PKG_ARCHIVE_DIR; use crate::lxc::LxcManager; -use crate::net::gateway::UpgradableListener; +use crate::net::gateway::WildcardListener; use crate::net::net_controller::{NetController, NetService}; use crate::net::socks::DEFAULT_SOCKS_LISTEN; use crate::net::utils::{find_eth_iface, find_wifi_iface}; @@ -132,7 +132,7 @@ pub struct RpcContext(Arc); impl RpcContext { #[instrument(skip_all)] pub async fn init( - webserver: &WebServerAcceptorSetter, + webserver: &WebServerAcceptorSetter, config: &ServerConfig, disk_guid: InternedString, init_result: Option, @@ -167,7 +167,7 @@ impl RpcContext { } else { let net_ctrl = Arc::new(NetController::init(db.clone(), &account.hostname, socks_proxy).await?); - webserver.try_upgrade(|a| net_ctrl.net_iface.watcher.upgrade_listener(a))?; + webserver.send_modify(|wl| wl.set_ip_info(net_ctrl.net_iface.watcher.subscribe())); let os_net_service = net_ctrl.os_bindings().await?; (net_ctrl, os_net_service) }; diff --git a/core/src/context/setup.rs b/core/src/context/setup.rs index bbfee9862..39a6b5fd3 100644 --- a/core/src/context/setup.rs +++ b/core/src/context/setup.rs @@ -20,7 +20,7 @@ use crate::context::RpcContext; use crate::context::config::ServerConfig; use crate::disk::mount::guard::{MountGuard, TmpMountGuard}; use crate::hostname::Hostname; -use crate::net::gateway::UpgradableListener; +use crate::net::gateway::WildcardListener; use crate::net::web_server::{WebServer, WebServerAcceptorSetter}; use crate::prelude::*; use crate::progress::FullProgressTracker; @@ -51,7 +51,7 @@ pub struct SetupResult { } pub struct SetupContextSeed { - pub webserver: WebServerAcceptorSetter, + pub webserver: WebServerAcceptorSetter, pub config: SyncMutex, pub disable_encryption: bool, pub progress: FullProgressTracker, @@ -70,7 +70,7 @@ pub struct SetupContext(Arc); impl SetupContext { #[instrument(skip_all)] pub fn init( - webserver: &WebServer, + webserver: &WebServer, config: ServerConfig, ) -> Result { let (shutdown, _) = tokio::sync::broadcast::channel(1); diff --git a/core/src/init.rs b/core/src/init.rs index 39680015e..e9507ef49 100644 --- a/core/src/init.rs +++ b/core/src/init.rs @@ -20,7 +20,7 @@ use crate::db::model::public::ServerStatus; use crate::developer::OS_DEVELOPER_KEY_PATH; use crate::hostname::Hostname; use crate::middleware::auth::local::LocalAuthContext; -use crate::net::gateway::UpgradableListener; +use crate::net::gateway::WildcardListener; use crate::net::net_controller::{NetController, NetService}; use crate::net::socks::DEFAULT_SOCKS_LISTEN; use crate::net::utils::find_wifi_iface; @@ -144,7 +144,7 @@ pub async fn run_script>(path: P, mut progress: PhaseProgressTrac #[instrument(skip_all)] pub async fn init( - webserver: &WebServerAcceptorSetter, + webserver: &WebServerAcceptorSetter, cfg: &ServerConfig, InitPhases { preinit, @@ -218,7 +218,7 @@ pub async fn init( ) .await?, ); - webserver.try_upgrade(|a| net_ctrl.net_iface.watcher.upgrade_listener(a))?; + webserver.send_modify(|wl| wl.set_ip_info(net_ctrl.net_iface.watcher.subscribe())); let os_net_service = net_ctrl.os_bindings().await?; start_net.complete(); diff --git a/core/src/net/forward.rs b/core/src/net/forward.rs index 7cf5e7985..b9fe30a37 100644 --- a/core/src/net/forward.rs +++ b/core/src/net/forward.rs @@ -15,7 +15,6 @@ use tokio::sync::mpsc; use crate::GatewayId; use crate::context::{CliContext, RpcContext}; use crate::db::model::public::NetworkInterfaceInfo; -use crate::net::gateway::{DynInterfaceFilter, InterfaceFilter}; use crate::prelude::*; use crate::util::Invoke; use crate::util::future::NonDetachingJoinHandle; @@ -31,6 +30,33 @@ fn is_restricted(port: u16) -> bool { port <= 1024 || RESTRICTED_PORTS.contains(&port) } +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct ForwardRequirements { + pub public_gateways: BTreeSet, + pub private_ips: BTreeSet, + pub secure: bool, +} + +impl std::fmt::Display for ForwardRequirements { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ForwardRequirements {{ public: {:?}, private: {:?}, secure: {} }}", + self.public_gateways, self.private_ips, self.secure + ) + } +} + +/// Source-IP filter for private forwards: restricts traffic to a subnet +/// while excluding gateway/router IPs that may masquerade internet traffic. +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct SourceFilter { + /// Network CIDR to allow (e.g. "192.168.1.0/24") + subnet: String, + /// Comma-separated gateway IPs to exclude (they may masquerade internet traffic) + excluded: String, +} + #[derive(Debug, Deserialize, Serialize)] pub struct AvailablePorts(BTreeMap); impl AvailablePorts { @@ -85,10 +111,10 @@ pub fn forward_api() -> ParentHandler { } let mut table = Table::new(); - table.add_row(row![bc => "FROM", "TO", "FILTER"]); + table.add_row(row![bc => "FROM", "TO", "REQS"]); for (external, target) in res.0 { - table.add_row(row![external, target.target, target.filter]); + table.add_row(row![external, target.target, target.reqs]); } table.print_tty(false)?; @@ -103,6 +129,7 @@ struct ForwardMapping { source: SocketAddrV4, target: SocketAddrV4, target_prefix: u8, + src_filter: Option, rc: Weak<()>, } @@ -117,9 +144,10 @@ impl PortForwardState { source: SocketAddrV4, target: SocketAddrV4, target_prefix: u8, + src_filter: Option, ) -> Result, Error> { if let Some(existing) = self.mappings.get_mut(&source) { - if existing.target == target { + if existing.target == target && existing.src_filter == src_filter { if let Some(existing_rc) = existing.rc.upgrade() { return Ok(existing_rc); } else { @@ -128,21 +156,28 @@ impl PortForwardState { return Ok(rc); } } else { - // Different target, need to remove old and add new + // Different target or src_filter, need to remove old and add new if let Some(mapping) = self.mappings.remove(&source) { - unforward(mapping.source, mapping.target, mapping.target_prefix).await?; + unforward( + mapping.source, + mapping.target, + mapping.target_prefix, + mapping.src_filter.as_ref(), + ) + .await?; } } } let rc = Arc::new(()); - forward(source, target, target_prefix).await?; + forward(source, target, target_prefix, src_filter.as_ref()).await?; self.mappings.insert( source, ForwardMapping { source, target, target_prefix, + src_filter, rc: Arc::downgrade(&rc), }, ); @@ -160,7 +195,13 @@ impl PortForwardState { for source in to_remove { if let Some(mapping) = self.mappings.remove(&source) { - unforward(mapping.source, mapping.target, mapping.target_prefix).await?; + unforward( + mapping.source, + mapping.target, + mapping.target_prefix, + mapping.src_filter.as_ref(), + ) + .await?; } } Ok(()) @@ -181,9 +222,14 @@ impl Drop for PortForwardState { let mappings = std::mem::take(&mut self.mappings); tokio::spawn(async move { for (_, mapping) in mappings { - unforward(mapping.source, mapping.target, mapping.target_prefix) - .await - .log_err(); + unforward( + mapping.source, + mapping.target, + mapping.target_prefix, + mapping.src_filter.as_ref(), + ) + .await + .log_err(); } }); } @@ -195,6 +241,7 @@ enum PortForwardCommand { source: SocketAddrV4, target: SocketAddrV4, target_prefix: u8, + src_filter: Option, respond: oneshot::Sender, Error>>, }, Gc { @@ -281,9 +328,12 @@ impl PortForwardController { source, target, target_prefix, + src_filter, respond, } => { - let result = state.add_forward(source, target, target_prefix).await; + let result = state + .add_forward(source, target, target_prefix, src_filter) + .await; respond.send(result).ok(); } PortForwardCommand::Gc { respond } => { @@ -308,6 +358,7 @@ impl PortForwardController { source: SocketAddrV4, target: SocketAddrV4, target_prefix: u8, + src_filter: Option, ) -> Result, Error> { let (send, recv) = oneshot::channel(); self.req @@ -315,6 +366,7 @@ impl PortForwardController { source, target, target_prefix, + src_filter, respond: send, }) .map_err(err_has_exited)?; @@ -345,14 +397,14 @@ struct InterfaceForwardRequest { external: u16, target: SocketAddrV4, target_prefix: u8, - filter: DynInterfaceFilter, + reqs: ForwardRequirements, rc: Arc<()>, } #[derive(Clone)] struct InterfaceForwardEntry { external: u16, - filter: BTreeMap)>, + targets: BTreeMap)>, // Maps source SocketAddr -> strong reference for the forward created in PortForwardController forwards: BTreeMap>, } @@ -370,7 +422,7 @@ impl InterfaceForwardEntry { fn new(external: u16) -> Self { Self { external, - filter: BTreeMap::new(), + targets: BTreeMap::new(), forwards: BTreeMap::new(), } } @@ -382,28 +434,50 @@ impl InterfaceForwardEntry { ) -> Result<(), Error> { let mut keep = BTreeSet::::new(); - for (iface, info) in ip_info.iter() { - if let Some((target, target_prefix)) = self - .filter - .iter() - .filter(|(_, (_, _, rc))| rc.strong_count() > 0) - .find(|(filter, _)| filter.filter(iface, info)) - .map(|(_, (target, target_prefix, _))| (*target, *target_prefix)) - { - if let Some(ip_info) = &info.ip_info { - for addr in ip_info.subnets.iter().filter_map(|net| { - if let IpAddr::V4(ip) = net.addr() { - Some(SocketAddrV4::new(ip, self.external)) - } else { - None + for (gw_id, info) in ip_info.iter() { + if let Some(ip_info) = &info.ip_info { + for subnet in ip_info.subnets.iter() { + if let IpAddr::V4(ip) = subnet.addr() { + let addr = SocketAddrV4::new(ip, self.external); + if keep.contains(&addr) { + continue; } - }) { - keep.insert(addr); - if !self.forwards.contains_key(&addr) { - let rc = port_forward - .add_forward(addr, target, target_prefix) + + for (reqs, (target, target_prefix, rc)) in self.targets.iter() { + if rc.strong_count() == 0 { + continue; + } + if !reqs.secure && !info.secure() { + continue; + } + + let src_filter = + if reqs.public_gateways.contains(gw_id) { + None + } else if reqs.private_ips.contains(&IpAddr::V4(ip)) { + let excluded = ip_info + .lan_ip + .iter() + .filter_map(|ip| match ip { + IpAddr::V4(v4) => Some(v4.to_string()), + _ => None, + }) + .collect::>() + .join(","); + Some(SourceFilter { + subnet: subnet.trunc().to_string(), + excluded, + }) + } else { + continue; + }; + + keep.insert(addr); + let fwd_rc = port_forward + .add_forward(addr, *target, *target_prefix, src_filter) .await?; - self.forwards.insert(addr, rc); + self.forwards.insert(addr, fwd_rc); + break; } } } @@ -422,7 +496,7 @@ impl InterfaceForwardEntry { external, target, target_prefix, - filter, + reqs, mut rc, }: InterfaceForwardRequest, ip_info: &OrdMap, @@ -436,8 +510,8 @@ impl InterfaceForwardEntry { } let entry = self - .filter - .entry(filter) + .targets + .entry(reqs) .or_insert_with(|| (target, target_prefix, Arc::downgrade(&rc))); if entry.0 != target { entry.0 = target; @@ -460,7 +534,7 @@ impl InterfaceForwardEntry { ip_info: &OrdMap, port_forward: &PortForwardController, ) -> Result<(), Error> { - self.filter.retain(|_, (_, _, rc)| rc.strong_count() > 0); + self.targets.retain(|_, (_, _, rc)| rc.strong_count() > 0); self.update(ip_info, port_forward).await } @@ -519,7 +593,7 @@ pub struct ForwardTable(pub BTreeMap); pub struct ForwardTarget { pub target: SocketAddrV4, pub target_prefix: u8, - pub filter: String, + pub reqs: String, } impl From<&InterfaceForwardState> for ForwardTable { @@ -530,16 +604,16 @@ impl From<&InterfaceForwardState> for ForwardTable { .iter() .flat_map(|entry| { entry - .filter + .targets .iter() .filter(|(_, (_, _, rc))| rc.strong_count() > 0) - .map(|(filter, (target, target_prefix, _))| { + .map(|(reqs, (target, target_prefix, _))| { ( entry.external, ForwardTarget { target: *target, target_prefix: *target_prefix, - filter: format!("{:#?}", filter), + reqs: format!("{reqs}"), }, ) }) @@ -558,16 +632,6 @@ enum InterfaceForwardCommand { DumpTable(oneshot::Sender), } -#[test] -fn test() { - use crate::net::gateway::SecureFilter; - - assert_ne!( - false.into_dyn(), - SecureFilter { secure: false }.into_dyn().into_dyn() - ); -} - pub struct InterfacePortForwardController { req: mpsc::UnboundedSender, _thread: NonDetachingJoinHandle<()>, @@ -617,7 +681,7 @@ impl InterfacePortForwardController { pub async fn add( &self, external: u16, - filter: DynInterfaceFilter, + reqs: ForwardRequirements, target: SocketAddrV4, target_prefix: u8, ) -> Result, Error> { @@ -629,7 +693,7 @@ impl InterfacePortForwardController { external, target, target_prefix, - filter, + reqs, rc, }, send, @@ -661,15 +725,21 @@ async fn forward( source: SocketAddrV4, target: SocketAddrV4, target_prefix: u8, + src_filter: Option<&SourceFilter>, ) -> Result<(), Error> { - Command::new("/usr/lib/startos/scripts/forward-port") - .env("sip", source.ip().to_string()) + let mut cmd = Command::new("/usr/lib/startos/scripts/forward-port"); + cmd.env("sip", source.ip().to_string()) .env("dip", target.ip().to_string()) .env("dprefix", target_prefix.to_string()) .env("sport", source.port().to_string()) - .env("dport", target.port().to_string()) - .invoke(ErrorKind::Network) - .await?; + .env("dport", target.port().to_string()); + if let Some(filter) = src_filter { + cmd.env("src_subnet", &filter.subnet); + if !filter.excluded.is_empty() { + cmd.env("excluded_src", &filter.excluded); + } + } + cmd.invoke(ErrorKind::Network).await?; Ok(()) } @@ -677,15 +747,21 @@ async fn unforward( source: SocketAddrV4, target: SocketAddrV4, target_prefix: u8, + src_filter: Option<&SourceFilter>, ) -> Result<(), Error> { - Command::new("/usr/lib/startos/scripts/forward-port") - .env("UNDO", "1") + let mut cmd = Command::new("/usr/lib/startos/scripts/forward-port"); + cmd.env("UNDO", "1") .env("sip", source.ip().to_string()) .env("dip", target.ip().to_string()) .env("dprefix", target_prefix.to_string()) .env("sport", source.port().to_string()) - .env("dport", target.port().to_string()) - .invoke(ErrorKind::Network) - .await?; + .env("dport", target.port().to_string()); + if let Some(filter) = src_filter { + cmd.env("src_subnet", &filter.subnet); + if !filter.excluded.is_empty() { + cmd.env("excluded_src", &filter.excluded); + } + } + cmd.invoke(ErrorKind::Network).await?; Ok(()) } diff --git a/core/src/net/gateway.rs b/core/src/net/gateway.rs index 688892f85..e9c2575ba 100644 --- a/core/src/net/gateway.rs +++ b/core/src/net/gateway.rs @@ -1,14 +1,11 @@ -use std::any::Any; use std::collections::{BTreeMap, BTreeSet, HashMap}; -use std::fmt; use std::future::Future; -use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV6}; -use std::sync::{Arc, Weak}; -use std::task::{Poll, ready}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::sync::Arc; +use std::task::Poll; use std::time::Duration; use clap::Parser; -use futures::future::Either; use futures::{FutureExt, Stream, StreamExt, TryStreamExt}; use imbl::{OrdMap, OrdSet}; use imbl_value::InternedString; @@ -36,15 +33,14 @@ use crate::db::model::Database; use crate::db::model::public::{IpInfo, NetworkInterfaceInfo, NetworkInterfaceType}; use crate::net::forward::START9_BRIDGE_IFACE; use crate::net::gateway::device::DeviceProxy; -use crate::net::utils::ipv6_is_link_local; -use crate::net::web_server::{Accept, AcceptStream, Acceptor, MetadataVisitor}; +use crate::net::web_server::{Accept, AcceptStream, MetadataVisitor, TcpMetadata}; use crate::prelude::*; use crate::util::Invoke; use crate::util::collections::OrdMapIterMut; use crate::util::future::{NonDetachingJoinHandle, Until}; use crate::util::io::open_file; use crate::util::serde::{HandlerExtSerde, display_serializable}; -use crate::util::sync::{SyncMutex, Watch}; +use crate::util::sync::Watch; pub fn gateway_api() -> ParentHandler { ParentHandler::new() @@ -838,7 +834,6 @@ pub struct NetworkInterfaceWatcher { activated: Watch>, ip_info: Watch>, _watcher: NonDetachingJoinHandle<()>, - listeners: SyncMutex>>, } impl NetworkInterfaceWatcher { pub fn new( @@ -858,7 +853,6 @@ impl NetworkInterfaceWatcher { watcher(ip_info, activated).await }) .into(), - listeners: SyncMutex::new(BTreeMap::new()), } } @@ -885,51 +879,6 @@ impl NetworkInterfaceWatcher { pub fn ip_info(&self) -> OrdMap { self.ip_info.read() } - - pub fn bind(&self, bind: B, port: u16) -> Result, Error> { - let arc = Arc::new(()); - self.listeners.mutate(|l| { - if l.get(&port).filter(|w| w.strong_count() > 0).is_some() { - return Err(Error::new( - std::io::Error::from_raw_os_error(libc::EADDRINUSE), - ErrorKind::Network, - )); - } - l.insert(port, Arc::downgrade(&arc)); - Ok(()) - })?; - let ip_info = self.ip_info.clone_unseen(); - Ok(NetworkInterfaceListener { - _arc: arc, - ip_info, - listeners: ListenerMap::new(bind, port), - }) - } - - pub fn upgrade_listener( - &self, - SelfContainedNetworkInterfaceListener { - mut listener, - .. - }: SelfContainedNetworkInterfaceListener, - ) -> Result, Error> { - let port = listener.listeners.port; - let arc = &listener._arc; - self.listeners.mutate(|l| { - if l.get(&port).filter(|w| w.strong_count() > 0).is_some() { - return Err(Error::new( - std::io::Error::from_raw_os_error(libc::EADDRINUSE), - ErrorKind::Network, - )); - } - l.insert(port, Arc::downgrade(arc)); - Ok(()) - })?; - let ip_info = self.ip_info.clone_unseen(); - ip_info.mark_changed(); - listener.change_ip_info_source(ip_info); - Ok(listener) - } } pub struct NetworkInterfaceController { @@ -1237,235 +1186,6 @@ impl NetworkInterfaceController { } } -pub trait InterfaceFilter: Any + Clone + std::fmt::Debug + Eq + Ord + Send + Sync { - fn filter(&self, id: &GatewayId, info: &NetworkInterfaceInfo) -> bool; - fn eq(&self, other: &dyn Any) -> bool { - Some(self) == other.downcast_ref::() - } - fn cmp(&self, other: &dyn Any) -> std::cmp::Ordering { - match (self as &dyn Any).type_id().cmp(&other.type_id()) { - std::cmp::Ordering::Equal => { - std::cmp::Ord::cmp(self, other.downcast_ref::().unwrap()) - } - ord => ord, - } - } - fn as_any(&self) -> &dyn Any { - self - } - fn into_dyn(self) -> DynInterfaceFilter { - DynInterfaceFilter::new(self) - } -} - -impl InterfaceFilter for bool { - fn filter(&self, _: &GatewayId, _: &NetworkInterfaceInfo) -> bool { - *self - } -} - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct TypeFilter(pub NetworkInterfaceType); -impl InterfaceFilter for TypeFilter { - fn filter(&self, _: &GatewayId, info: &NetworkInterfaceInfo) -> bool { - info.ip_info.as_ref().and_then(|i| i.device_type) == Some(self.0) - } -} - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct IdFilter(pub GatewayId); -impl InterfaceFilter for IdFilter { - fn filter(&self, id: &GatewayId, _: &NetworkInterfaceInfo) -> bool { - id == &self.0 - } -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct PublicFilter { - pub public: bool, -} -impl InterfaceFilter for PublicFilter { - fn filter(&self, _: &GatewayId, info: &NetworkInterfaceInfo) -> bool { - self.public == info.public() - } -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct SecureFilter { - pub secure: bool, -} -impl InterfaceFilter for SecureFilter { - fn filter(&self, _: &GatewayId, info: &NetworkInterfaceInfo) -> bool { - self.secure || info.secure() - } -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct AndFilter(pub A, pub B); -impl InterfaceFilter for AndFilter { - fn filter(&self, id: &GatewayId, info: &NetworkInterfaceInfo) -> bool { - self.0.filter(id, info) && self.1.filter(id, info) - } -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct OrFilter(pub A, pub B); -impl InterfaceFilter for OrFilter { - fn filter(&self, id: &GatewayId, info: &NetworkInterfaceInfo) -> bool { - self.0.filter(id, info) || self.1.filter(id, info) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub struct AnyFilter(pub BTreeSet); -impl InterfaceFilter for AnyFilter { - fn filter(&self, id: &GatewayId, info: &NetworkInterfaceInfo) -> bool { - self.0.iter().any(|f| InterfaceFilter::filter(f, id, info)) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub struct AllFilter(pub BTreeSet); -impl InterfaceFilter for AllFilter { - fn filter(&self, id: &GatewayId, info: &NetworkInterfaceInfo) -> bool { - self.0.iter().all(|f| InterfaceFilter::filter(f, id, info)) - } -} - -pub trait DynInterfaceFilterT: std::fmt::Debug + Any + Send + Sync { - fn filter(&self, id: &GatewayId, info: &NetworkInterfaceInfo) -> bool; - fn eq(&self, other: &dyn Any) -> bool; - fn cmp(&self, other: &dyn Any) -> std::cmp::Ordering; - fn as_any(&self) -> &dyn Any; -} -impl DynInterfaceFilterT for T { - fn filter(&self, id: &GatewayId, info: &NetworkInterfaceInfo) -> bool { - InterfaceFilter::filter(self, id, info) - } - fn eq(&self, other: &dyn Any) -> bool { - InterfaceFilter::eq(self, other) - } - fn cmp(&self, other: &dyn Any) -> std::cmp::Ordering { - InterfaceFilter::cmp(self, other) - } - fn as_any(&self) -> &dyn Any { - InterfaceFilter::as_any(self) - } -} - -#[test] -fn test_interface_filter_eq() { - let dyn_t = true.into_dyn(); - assert!(DynInterfaceFilterT::eq( - &dyn_t, - DynInterfaceFilterT::as_any(&true), - )) -} - -#[derive(Clone, Debug)] -pub struct DynInterfaceFilter(Arc); -impl InterfaceFilter for DynInterfaceFilter { - fn filter(&self, id: &GatewayId, info: &NetworkInterfaceInfo) -> bool { - self.0.filter(id, info) - } - fn eq(&self, other: &dyn Any) -> bool { - self.0.eq(other) - } - fn cmp(&self, other: &dyn Any) -> std::cmp::Ordering { - self.0.cmp(other) - } - fn as_any(&self) -> &dyn Any { - self.0.as_any() - } - fn into_dyn(self) -> DynInterfaceFilter { - self - } -} -impl DynInterfaceFilter { - fn new(value: T) -> Self { - Self(Arc::new(value)) - } -} -impl PartialEq for DynInterfaceFilter { - fn eq(&self, other: &Self) -> bool { - DynInterfaceFilterT::eq(&*self.0, DynInterfaceFilterT::as_any(&*other.0)) - } -} -impl Eq for DynInterfaceFilter {} -impl PartialOrd for DynInterfaceFilter { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.0.cmp(other.0.as_any())) - } -} -impl Ord for DynInterfaceFilter { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.0.cmp(other.0.as_any()) - } -} - -struct ListenerMap { - prev_filter: DynInterfaceFilter, - bind: B, - port: u16, - listeners: BTreeMap, -} -impl ListenerMap { - fn new(bind: B, port: u16) -> Self { - Self { - prev_filter: false.into_dyn(), - bind, - port, - listeners: BTreeMap::new(), - } - } - - #[instrument(skip(self))] - fn update( - &mut self, - ip_info: &OrdMap, - filter: &impl InterfaceFilter, - ) -> Result<(), Error> { - let mut keep = BTreeSet::::new(); - for (_, info) in ip_info - .iter() - .filter(|(id, info)| filter.filter(*id, *info)) - { - if let Some(ip_info) = &info.ip_info { - for ipnet in &ip_info.subnets { - let addr = match ipnet.addr() { - IpAddr::V6(ip6) => SocketAddrV6::new( - ip6, - self.port, - 0, - if ipv6_is_link_local(ip6) { - ip_info.scope_id - } else { - 0 - }, - ) - .into(), - ip => SocketAddr::new(ip, self.port), - }; - keep.insert(addr); - if !self.listeners.contains_key(&addr) { - self.listeners.insert(addr, self.bind.bind(addr)?); - } - } - } - } - self.listeners.retain(|key, _| keep.contains(key)); - self.prev_filter = filter.clone().into_dyn(); - Ok(()) - } - fn poll_accept( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll::Metadata, AcceptStream), Error>> { - let (metadata, stream) = ready!(self.listeners.poll_accept(cx)?); - Poll::Ready(Ok((metadata.key, metadata.inner, stream))) - } -} - pub fn lookup_info_by_addr( ip_info: &OrdMap, addr: SocketAddr, @@ -1477,28 +1197,6 @@ pub fn lookup_info_by_addr( }) } -pub trait Bind { - type Accept: Accept; - fn bind(&mut self, addr: SocketAddr) -> Result; -} - -#[derive(Clone, Copy, Default)] -pub struct BindTcp; -impl Bind for BindTcp { - type Accept = TcpListener; - fn bind(&mut self, addr: SocketAddr) -> Result { - TcpListener::from_std( - mio::net::TcpListener::bind(addr) - .with_kind(ErrorKind::Network)? - .into(), - ) - .with_kind(ErrorKind::Network) - } -} - -pub trait FromGatewayInfo { - fn from_gateway_info(id: &GatewayId, info: &NetworkInterfaceInfo) -> Self; -} #[derive(Clone, Debug)] pub struct GatewayInfo { pub id: GatewayId, @@ -1509,202 +1207,88 @@ impl Visit for GatewayInfo { visitor.visit(self) } } -impl FromGatewayInfo for GatewayInfo { - fn from_gateway_info(id: &GatewayId, info: &NetworkInterfaceInfo) -> Self { - Self { - id: id.clone(), - info: info.clone(), - } - } -} -pub struct NetworkInterfaceListener { - pub ip_info: Watch>, - listeners: ListenerMap, - _arc: Arc<()>, -} -impl NetworkInterfaceListener { - pub(super) fn new( - mut ip_info: Watch>, - bind: B, - port: u16, - ) -> Self { - ip_info.mark_unseen(); - Self { - ip_info, - listeners: ListenerMap::new(bind, port), - _arc: Arc::new(()), - } - } - - pub fn port(&self) -> u16 { - self.listeners.port - } - - #[cfg_attr(feature = "unstable", inline(never))] - pub fn poll_accept( - &mut self, - cx: &mut std::task::Context<'_>, - filter: &impl InterfaceFilter, - ) -> Poll::Metadata, AcceptStream), Error>> { - while self.ip_info.poll_changed(cx).is_ready() - || !DynInterfaceFilterT::eq(&self.listeners.prev_filter, filter.as_any()) - { - self.ip_info - .peek_and_mark_seen(|ip_info| self.listeners.update(ip_info, filter))?; - } - let (addr, inner, stream) = ready!(self.listeners.poll_accept(cx)?); - Poll::Ready(Ok(( - self.ip_info - .peek(|ip_info| { - lookup_info_by_addr(ip_info, addr) - .map(|(id, info)| M::from_gateway_info(id, info)) - }) - .or_not_found(lazy_format!("gateway for {addr}"))?, - inner, - stream, - ))) - } - - pub fn change_ip_info_source( - &mut self, - mut ip_info: Watch>, - ) { - ip_info.mark_unseen(); - self.ip_info = ip_info; - } - - pub async fn accept( - &mut self, - filter: &impl InterfaceFilter, - ) -> Result<(M, ::Metadata, AcceptStream), Error> { - futures::future::poll_fn(|cx| self.poll_accept(cx, filter)).await - } - - pub fn check_filter(&self) -> impl FnOnce(SocketAddr, &DynInterfaceFilter) -> bool + 'static { - let ip_info = self.ip_info.clone(); - move |addr, filter| { - ip_info.peek(|i| { - lookup_info_by_addr(i, addr).map_or(false, |(id, info)| { - InterfaceFilter::filter(filter, id, info) - }) - }) - } - } -} - -#[derive(VisitFields)] -pub struct NetworkInterfaceListenerAcceptMetadata { - pub inner: ::Metadata, +/// Metadata for connections accepted by WildcardListener or VHostBindListener. +#[derive(Clone, Debug, VisitFields)] +pub struct NetworkInterfaceListenerAcceptMetadata { + pub inner: TcpMetadata, pub info: GatewayInfo, } -impl fmt::Debug for NetworkInterfaceListenerAcceptMetadata { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("NetworkInterfaceListenerAcceptMetadata") - .field("inner", &self.inner) - .field("info", &self.info) - .finish() - } -} -impl Clone for NetworkInterfaceListenerAcceptMetadata -where - ::Metadata: Clone, -{ - fn clone(&self) -> Self { - Self { - inner: self.inner.clone(), - info: self.info.clone(), - } - } -} -impl Visit for NetworkInterfaceListenerAcceptMetadata -where - B: Bind, - ::Metadata: Visit + Clone + Send + Sync + 'static, - V: MetadataVisitor, -{ +impl Visit for NetworkInterfaceListenerAcceptMetadata { fn visit(&self, visitor: &mut V) -> V::Result { self.visit_fields(visitor).collect() } } -impl Accept for NetworkInterfaceListener { - type Metadata = NetworkInterfaceListenerAcceptMetadata; +/// A simple TCP listener on 0.0.0.0:port that looks up GatewayInfo from the +/// connection's local address on each accepted connection. +pub struct WildcardListener { + listener: TcpListener, + ip_info: Watch>, + /// Handle to the self-contained watcher task started in `new()`. + /// Dropped (and thus aborted) when `set_ip_info` replaces the ip_info source. + _watcher: Option>, +} +impl WildcardListener { + pub fn new(port: u16) -> Result { + let listener = TcpListener::from_std( + mio::net::TcpListener::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port)) + .with_kind(ErrorKind::Network)? + .into(), + ) + .with_kind(ErrorKind::Network)?; + let ip_info = Watch::new(OrdMap::new()); + let watcher_handle = + tokio::spawn(watcher(ip_info.clone(), Watch::new(BTreeMap::new()))).into(); + Ok(Self { + listener, + ip_info, + _watcher: Some(watcher_handle), + }) + } + + /// Replace the ip_info source with the one from the NetworkInterfaceController. + /// Aborts the self-contained watcher task. + pub fn set_ip_info(&mut self, ip_info: Watch>) { + self.ip_info = ip_info; + self._watcher = None; + } +} +impl Accept for WildcardListener { + type Metadata = NetworkInterfaceListenerAcceptMetadata; fn poll_accept( &mut self, cx: &mut std::task::Context<'_>, ) -> Poll> { - NetworkInterfaceListener::poll_accept(self, cx, &true).map(|res| { - res.map(|(info, inner, stream)| { - ( - NetworkInterfaceListenerAcceptMetadata { inner, info }, - stream, - ) - }) - }) - } -} - -pub struct SelfContainedNetworkInterfaceListener { - _watch_thread: NonDetachingJoinHandle<()>, - listener: NetworkInterfaceListener, -} -impl SelfContainedNetworkInterfaceListener { - pub fn bind(bind: B, port: u16) -> Self { - let ip_info = Watch::new(OrdMap::new()); - let _watch_thread = - tokio::spawn(watcher(ip_info.clone(), Watch::new(BTreeMap::new()))).into(); - Self { - _watch_thread, - listener: NetworkInterfaceListener::new(ip_info, bind, port), + if let Poll::Ready((stream, peer_addr)) = TcpListener::poll_accept(&self.listener, cx)? { + if let Err(e) = socket2::SockRef::from(&stream).set_keepalive(true) { + tracing::error!("Failed to set tcp keepalive: {e}"); + tracing::debug!("{e:?}"); + } + let local_addr = stream.local_addr()?; + let info = self + .ip_info + .peek(|ip_info| { + lookup_info_by_addr(ip_info, local_addr).map(|(id, info)| GatewayInfo { + id: id.clone(), + info: info.clone(), + }) + }) + .unwrap_or_else(|| GatewayInfo { + id: InternedString::from_static("").into(), + info: NetworkInterfaceInfo::default(), + }); + return Poll::Ready(Ok(( + NetworkInterfaceListenerAcceptMetadata { + inner: TcpMetadata { + local_addr, + peer_addr, + }, + info, + }, + Box::pin(stream), + ))); } + Poll::Pending } } -impl Accept for SelfContainedNetworkInterfaceListener { - type Metadata = as Accept>::Metadata; - fn poll_accept( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - Accept::poll_accept(&mut self.listener, cx) - } -} - -pub type UpgradableListener = - Option, NetworkInterfaceListener>>; - -impl Acceptor> -where - B: Bind + Send + Sync + 'static, - B::Accept: Send + Sync, -{ - pub fn bind_upgradable(listener: SelfContainedNetworkInterfaceListener) -> Self { - Self::new(Some(Either::Left(listener))) - } -} - -#[test] -fn test_filter() { - let wg1 = "wg1".parse::().unwrap(); - assert!(!InterfaceFilter::filter( - &AndFilter(IdFilter(wg1.clone()), PublicFilter { public: false }).into_dyn(), - &wg1, - &NetworkInterfaceInfo { - name: None, - public: None, - secure: None, - ip_info: Some(Arc::new(IpInfo { - name: "".into(), - scope_id: 3, - device_type: Some(NetworkInterfaceType::Wireguard), - subnets: ["10.59.0.2/24".parse::().unwrap()] - .into_iter() - .collect(), - lan_ip: Default::default(), - wan_ip: None, - ntp_servers: Default::default(), - dns_servers: Default::default(), - })), - }, - )); -} diff --git a/core/src/net/host/binding.rs b/core/src/net/host/binding.rs index 3c78c4338..8db806399 100644 --- a/core/src/net/host/binding.rs +++ b/core/src/net/host/binding.rs @@ -8,17 +8,15 @@ use serde::{Deserialize, Serialize}; use ts_rs::TS; use crate::context::{CliContext, RpcContext}; -use crate::db::model::public::NetworkInterfaceInfo; use crate::db::prelude::Map; use crate::net::forward::AvailablePorts; -use crate::net::gateway::InterfaceFilter; use crate::net::host::HostApiKind; use crate::net::service_interface::HostnameInfo; use crate::net::vhost::AlpnInfo; use crate::prelude::*; use crate::util::FromStrParser; use crate::util::serde::{HandlerExtSerde, display_serializable}; -use crate::{GatewayId, HostId}; +use crate::HostId; #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, TS)] #[ts(export)] @@ -51,9 +49,9 @@ impl FromStr for BindId { #[ts(export)] #[model = "Model"] pub struct DerivedAddressInfo { - /// User-controlled: private-gateway addresses the user has disabled + /// User-controlled: private addresses the user has disabled pub private_disabled: BTreeSet, - /// User-controlled: public-gateway addresses the user has enabled + /// User-controlled: public addresses the user has enabled pub public_enabled: BTreeSet, /// COMPUTED: NetServiceData::update — all possible addresses for this binding pub possible: BTreeSet, @@ -76,26 +74,6 @@ impl DerivedAddressInfo { .collect() } - /// Derive a gateway-level InterfaceFilter from the enabled addresses. - /// A gateway passes the filter if it has any enabled address for this binding. - pub fn gateway_filter(&self) -> AddressFilter { - let enabled_gateways: BTreeSet = self - .enabled() - .into_iter() - .map(|h| h.gateway.id.clone()) - .collect(); - AddressFilter(enabled_gateways) - } -} - -/// Gateway-level filter derived from DerivedAddressInfo. -/// Passes if the gateway has at least one enabled address. -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub struct AddressFilter(pub BTreeSet); -impl InterfaceFilter for AddressFilter { - fn filter(&self, id: &GatewayId, info: &NetworkInterfaceInfo) -> bool { - info.ip_info.is_some() && self.0.contains(id) - } } #[derive(Debug, Default, Deserialize, Serialize, HasModel, TS)] @@ -145,12 +123,6 @@ pub struct NetInfo { pub assigned_port: Option, pub assigned_ssl_port: Option, } -impl InterfaceFilter for NetInfo { - fn filter(&self, _id: &GatewayId, info: &NetworkInterfaceInfo) -> bool { - info.ip_info.is_some() - } -} - impl BindInfo { pub fn new(available_ports: &mut AvailablePorts, options: BindOptions) -> Result { let mut assigned_port = None; diff --git a/core/src/net/net_controller.rs b/core/src/net/net_controller.rs index 56c30ed4e..37c63f333 100644 --- a/core/src/net/net_controller.rs +++ b/core/src/net/net_controller.rs @@ -16,11 +16,10 @@ use crate::db::model::public::NetworkInterfaceType; use crate::error::ErrorCollection; use crate::hostname::Hostname; use crate::net::dns::DnsController; -use crate::net::forward::{InterfacePortForwardController, START9_BRIDGE_IFACE, add_iptables_rule}; -use crate::net::gateway::{ - AndFilter, AnyFilter, DynInterfaceFilter, IdFilter, InterfaceFilter, - NetworkInterfaceController, OrFilter, PublicFilter, SecureFilter, +use crate::net::forward::{ + ForwardRequirements, InterfacePortForwardController, START9_BRIDGE_IFACE, add_iptables_rule, }; +use crate::net::gateway::NetworkInterfaceController; use crate::net::host::address::HostAddress; use crate::net::host::binding::{AddSslOptions, BindId, BindOptions}; use crate::net::host::{Host, Hosts, host_for}; @@ -31,7 +30,7 @@ use crate::net::vhost::{AlpnInfo, DynVHostTarget, ProxyTarget, VHostController}; use crate::prelude::*; use crate::service::effects::callbacks::ServiceCallbacks; use crate::util::serde::MaybeUtf8String; -use crate::{HOST_IP, HostId, OptionExt, PackageId}; +use crate::{GatewayId, HOST_IP, HostId, OptionExt, PackageId}; pub struct NetController { pub(crate) db: TypedPatchDb, @@ -161,7 +160,7 @@ impl NetController { #[derive(Default, Debug)] struct HostBinds { - forwards: BTreeMap)>, + forwards: BTreeMap)>, vhosts: BTreeMap<(Option, u16), (ProxyTarget, Arc<()>)>, private_dns: BTreeMap>, } @@ -257,213 +256,36 @@ impl NetServiceData { id: HostId, mut host: Host, ) -> Result<(), Error> { - let mut forwards: BTreeMap = BTreeMap::new(); + let mut forwards: BTreeMap = BTreeMap::new(); let mut vhosts: BTreeMap<(Option, u16), ProxyTarget> = BTreeMap::new(); let mut private_dns: BTreeSet = BTreeSet::new(); let binds = self.binds.entry(id.clone()).or_default(); let peek = ctrl.db.peek().await; - - // LAN let server_info = peek.as_public().as_server_info(); let net_ifaces = ctrl.net_iface.watcher.ip_info(); let hostname = server_info.as_hostname().de()?; let host_addresses: Vec<_> = host.addresses().collect(); - for (port, bind) in host.bindings.iter_mut() { + + // Collect private DNS entries (domains without public config) + for HostAddress { + address, public, .. + } in &host_addresses + { + if public.is_none() { + private_dns.insert(address.clone()); + } + } + + // ── Phase 1: Compute possible addresses ── + for (_port, bind) in host.bindings.iter_mut() { if !bind.enabled { continue; } if bind.net.assigned_port.is_none() && bind.net.assigned_ssl_port.is_none() { continue; } - let mut hostnames = BTreeSet::new(); - let mut gw_filter = AnyFilter( - [PublicFilter { public: false }.into_dyn()] - .into_iter() - .chain( - bind.addresses - .public_enabled - .iter() - .map(|a| a.gateway.id.clone()) - .collect::>() - .into_iter() - .map(IdFilter) - .map(InterfaceFilter::into_dyn), - ) - .collect(), - ); - if let Some(ssl) = &bind.options.add_ssl { - let external = bind - .net - .assigned_ssl_port - .or_not_found("assigned ssl port")?; - let addr = (self.ip, *port).into(); - let connect_ssl = if let Some(alpn) = ssl.alpn.clone() { - Err(alpn) - } else { - if bind.options.secure.as_ref().map_or(false, |s| s.ssl) { - Ok(()) - } else { - Err(AlpnInfo::Reflect) - } - }; - for hostname in ctrl.server_hostnames.iter().cloned() { - vhosts.insert( - (hostname, external), - ProxyTarget { - filter: gw_filter.clone().into_dyn(), - acme: None, - addr, - add_x_forwarded_headers: ssl.add_x_forwarded_headers, - connect_ssl: connect_ssl - .clone() - .map(|_| ctrl.tls_client_config.clone()), - }, // TODO: allow public traffic? - ); - } - for HostAddress { - address, - public, - private, - } in host_addresses.iter().cloned() - { - if hostnames.insert(address.clone()) { - let address = Some(address.clone()); - if ssl.preferred_external_port == 443 { - if let Some(public) = &public { - vhosts.insert( - (address.clone(), 5443), - ProxyTarget { - filter: AndFilter( - bind.net.clone(), - AndFilter( - IdFilter(public.gateway.clone()), - PublicFilter { public: false }, - ), - ) - .into_dyn(), - acme: public.acme.clone(), - addr, - add_x_forwarded_headers: ssl.add_x_forwarded_headers, - connect_ssl: connect_ssl - .clone() - .map(|_| ctrl.tls_client_config.clone()), - }, - ); - vhosts.insert( - (address.clone(), 443), - ProxyTarget { - filter: AndFilter( - bind.net.clone(), - if private { - OrFilter( - IdFilter(public.gateway.clone()), - PublicFilter { public: false }, - ) - .into_dyn() - } else { - AndFilter( - IdFilter(public.gateway.clone()), - PublicFilter { public: true }, - ) - .into_dyn() - }, - ) - .into_dyn(), - acme: public.acme.clone(), - addr, - add_x_forwarded_headers: ssl.add_x_forwarded_headers, - connect_ssl: connect_ssl - .clone() - .map(|_| ctrl.tls_client_config.clone()), - }, - ); - } else { - vhosts.insert( - (address.clone(), 443), - ProxyTarget { - filter: AndFilter( - bind.net.clone(), - PublicFilter { public: false }, - ) - .into_dyn(), - acme: None, - addr, - add_x_forwarded_headers: ssl.add_x_forwarded_headers, - connect_ssl: connect_ssl - .clone() - .map(|_| ctrl.tls_client_config.clone()), - }, - ); - } - } else { - if let Some(public) = public { - vhosts.insert( - (address.clone(), external), - ProxyTarget { - filter: AndFilter( - bind.net.clone(), - if private { - OrFilter( - IdFilter(public.gateway.clone()), - PublicFilter { public: false }, - ) - .into_dyn() - } else { - IdFilter(public.gateway.clone()).into_dyn() - }, - ) - .into_dyn(), - acme: public.acme.clone(), - addr, - add_x_forwarded_headers: ssl.add_x_forwarded_headers, - connect_ssl: connect_ssl - .clone() - .map(|_| ctrl.tls_client_config.clone()), - }, - ); - } else { - vhosts.insert( - (address.clone(), external), - ProxyTarget { - filter: AndFilter( - bind.net.clone(), - PublicFilter { public: false }, - ) - .into_dyn(), - acme: None, - addr, - add_x_forwarded_headers: ssl.add_x_forwarded_headers, - connect_ssl: connect_ssl - .clone() - .map(|_| ctrl.tls_client_config.clone()), - }, - ); - } - } - } - } - } - if bind - .options - .secure - .map_or(true, |s| !(s.ssl && bind.options.add_ssl.is_some())) - { - let external = bind.net.assigned_port.or_not_found("assigned lan port")?; - forwards.insert( - external, - ( - SocketAddrV4::new(self.ip, *port), - AndFilter( - SecureFilter { - secure: bind.options.secure.is_some(), - }, - bind.net.clone(), - ) - .into_dyn(), - ), - ); - } + bind.addresses.possible.clear(); for (gateway_id, info) in net_ifaces .iter() @@ -472,7 +294,7 @@ impl NetServiceData { !matches!(i.device_type, Some(NetworkInterfaceType::Bridge)) }) }) - .filter(|(id, info)| bind.net.filter(id, info)) + .filter(|(_, info)| info.ip_info.is_some()) { let gateway = GatewayInfo { id: gateway_id.clone(), @@ -488,6 +310,7 @@ impl NetServiceData { !(s.ssl && bind.options.add_ssl.is_some()) || info.secure() }) }); + // .local addresses (private only, non-public, non-wireguard gateways) if !info.public() && info.ip_info.as_ref().map_or(false, |i| { i.device_type != Some(NetworkInterfaceType::Wireguard) @@ -506,46 +329,39 @@ impl NetServiceData { }, }); } + // Domain addresses for HostAddress { address, public, private, } in host_addresses.iter().cloned() { - if public.is_none() { - private_dns.insert(address.clone()); - } let private = private && !info.public(); - let public = public.as_ref().map_or(false, |p| &p.gateway == gateway_id); + let public = + public.as_ref().map_or(false, |p| &p.gateway == gateway_id); if public || private { - if bind + let (domain_port, domain_ssl_port) = if bind .options .add_ssl .as_ref() .map_or(false, |ssl| ssl.preferred_external_port == 443) { - bind.addresses.possible.insert(HostnameInfo { - gateway: gateway.clone(), - public, - hostname: IpHostname::Domain { - value: address.clone(), - port: None, - ssl_port: Some(443), - }, - }); + (None, Some(443)) } else { - bind.addresses.possible.insert(HostnameInfo { - gateway: gateway.clone(), - public, - hostname: IpHostname::Domain { - value: address.clone(), - port, - ssl_port: bind.net.assigned_ssl_port, - }, - }); - } + (port, bind.net.assigned_ssl_port) + }; + bind.addresses.possible.insert(HostnameInfo { + gateway: gateway.clone(), + public, + hostname: IpHostname::Domain { + value: address.clone(), + port: domain_port, + ssl_port: domain_ssl_port, + }, + }); } } + // IP addresses if let Some(ip_info) = &info.ip_info { let public = info.public(); if let Some(wan_ip) = ip_info.wan_ip { @@ -592,6 +408,137 @@ impl NetServiceData { } } + // ── Phase 2: Build controller entries from enabled addresses ── + for (port, bind) in host.bindings.iter() { + if !bind.enabled { + continue; + } + if bind.net.assigned_port.is_none() && bind.net.assigned_ssl_port.is_none() { + continue; + } + + let enabled_addresses = bind.addresses.enabled(); + let addr: SocketAddr = (self.ip, *port).into(); + + // SSL vhosts + if let Some(ssl) = &bind.options.add_ssl { + let connect_ssl = if let Some(alpn) = ssl.alpn.clone() { + Err(alpn) + } else if bind.options.secure.as_ref().map_or(false, |s| s.ssl) { + Ok(()) + } else { + Err(AlpnInfo::Reflect) + }; + + if let Some(assigned_ssl_port) = bind.net.assigned_ssl_port { + // Collect private IPs from enabled private addresses' gateways + let server_private_ips: BTreeSet = enabled_addresses + .iter() + .filter(|a| !a.public) + .filter_map(|a| { + net_ifaces + .get(&a.gateway.id) + .and_then(|info| info.ip_info.as_ref()) + }) + .flat_map(|ip_info| ip_info.subnets.iter().map(|s| s.addr())) + .collect(); + + // Server hostname vhosts (on assigned_ssl_port) — private only + if !server_private_ips.is_empty() { + for hostname in ctrl.server_hostnames.iter().cloned() { + vhosts.insert( + (hostname, assigned_ssl_port), + ProxyTarget { + public: BTreeSet::new(), + private: server_private_ips.clone(), + acme: None, + addr, + add_x_forwarded_headers: ssl.add_x_forwarded_headers, + connect_ssl: connect_ssl + .clone() + .map(|_| ctrl.tls_client_config.clone()), + }, + ); + } + } + } + + // Domain vhosts: group by (domain, ssl_port), merge public/private sets + for addr_info in &enabled_addresses { + if let IpHostname::Domain { + value: domain, + ssl_port: Some(domain_ssl_port), + .. + } = &addr_info.hostname + { + let key = (Some(domain.clone()), *domain_ssl_port); + let target = vhosts.entry(key).or_insert_with(|| ProxyTarget { + public: BTreeSet::new(), + private: BTreeSet::new(), + acme: host_addresses + .iter() + .find(|a| &a.address == domain) + .and_then(|a| a.public.as_ref()) + .and_then(|p| p.acme.clone()), + addr, + add_x_forwarded_headers: ssl.add_x_forwarded_headers, + connect_ssl: connect_ssl + .clone() + .map(|_| ctrl.tls_client_config.clone()), + }); + if addr_info.public { + target.public.insert(addr_info.gateway.id.clone()); + } else { + // Add interface IPs for this gateway to private set + if let Some(info) = net_ifaces.get(&addr_info.gateway.id) { + if let Some(ip_info) = &info.ip_info { + for subnet in &ip_info.subnets { + target.private.insert(subnet.addr()); + } + } + } + } + } + } + } + + // Non-SSL forwards + if bind + .options + .secure + .map_or(true, |s| !(s.ssl && bind.options.add_ssl.is_some())) + { + let external = bind.net.assigned_port.or_not_found("assigned lan port")?; + let fwd_public: BTreeSet = enabled_addresses + .iter() + .filter(|a| a.public) + .map(|a| a.gateway.id.clone()) + .collect(); + let fwd_private: BTreeSet = enabled_addresses + .iter() + .filter(|a| !a.public) + .filter_map(|a| { + net_ifaces + .get(&a.gateway.id) + .and_then(|i| i.ip_info.as_ref()) + }) + .flat_map(|ip| ip.subnets.iter().map(|s| s.addr())) + .collect(); + forwards.insert( + external, + ( + SocketAddrV4::new(self.ip, *port), + ForwardRequirements { + public_gateways: fwd_public, + private_ips: fwd_private, + secure: bind.options.secure.is_some(), + }, + ), + ); + } + } + + // ── Phase 3: Reconcile ── let all = binds .forwards .keys() @@ -600,8 +547,8 @@ impl NetServiceData { .collect::>(); for external in all { let mut prev = binds.forwards.remove(&external); - if let Some((internal, filter)) = forwards.remove(&external) { - prev = prev.filter(|(i, f, _)| i == &internal && *f == filter); + if let Some((internal, reqs)) = forwards.remove(&external) { + prev = prev.filter(|(i, r, _)| i == &internal && *r == reqs); binds.forwards.insert( external, if let Some(prev) = prev { @@ -609,11 +556,11 @@ impl NetServiceData { } else { ( internal, - filter.clone(), + reqs.clone(), ctrl.forward .add( external, - filter, + reqs, internal, net_ifaces .iter() diff --git a/core/src/net/vhost.rs b/core/src/net/vhost.rs index 4996ca937..9023576c3 100644 --- a/core/src/net/vhost.rs +++ b/core/src/net/vhost.rs @@ -1,19 +1,19 @@ use std::any::Any; use std::collections::{BTreeMap, BTreeSet}; use std::fmt; -use std::net::{IpAddr, SocketAddr}; +use std::net::{IpAddr, SocketAddr, SocketAddrV6}; use std::sync::{Arc, Weak}; use std::task::{Poll, ready}; -use std::time::Duration; use async_acme::acme::ACME_TLS_ALPN_NAME; use color_eyre::eyre::eyre; use futures::FutureExt; use futures::future::BoxFuture; +use imbl::OrdMap; use imbl_value::{InOMap, InternedString}; use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn}; use serde::{Deserialize, Serialize}; -use tokio::net::TcpStream; +use tokio::net::{TcpListener, TcpStream}; use tokio_rustls::TlsConnector; use tokio_rustls::rustls::crypto::CryptoProvider; use tokio_rustls::rustls::pki_types::ServerName; @@ -23,28 +23,28 @@ use tracing::instrument; use ts_rs::TS; use visit_rs::Visit; -use crate::ResultExt; use crate::context::{CliContext, RpcContext}; use crate::db::model::Database; -use crate::db::model::public::AcmeSettings; +use crate::db::model::public::{AcmeSettings, NetworkInterfaceInfo}; use crate::db::{DbAccessByKey, DbAccessMut}; use crate::net::acme::{ AcmeCertStore, AcmeProvider, AcmeTlsAlpnCache, AcmeTlsHandler, GetAcmeProvider, }; use crate::net::gateway::{ - AnyFilter, BindTcp, DynInterfaceFilter, GatewayInfo, InterfaceFilter, - NetworkInterfaceController, NetworkInterfaceListener, + GatewayInfo, NetworkInterfaceController, NetworkInterfaceListenerAcceptMetadata, }; use crate::net::ssl::{CertStore, RootCaTlsHandler}; use crate::net::tls::{ ChainedHandler, TlsHandlerWrapper, TlsListener, TlsMetadata, WrapTlsHandler, }; +use crate::net::utils::ipv6_is_link_local; use crate::net::web_server::{Accept, AcceptStream, ExtractVisitor, TcpMetadata, extract}; use crate::prelude::*; use crate::util::collections::EqSet; use crate::util::future::{NonDetachingJoinHandle, WeakFuture}; use crate::util::serde::{HandlerExtSerde, MaybeUtf8String, display_serializable}; use crate::util::sync::{SyncMutex, Watch}; +use crate::{GatewayId, ResultExt}; pub fn vhost_api() -> ParentHandler { ParentHandler::new().subcommand( @@ -93,7 +93,7 @@ pub struct VHostController { interfaces: Arc, crypto_provider: Arc, acme_cache: AcmeTlsAlpnCache, - servers: SyncMutex>>, + servers: SyncMutex>>, } impl VHostController { pub fn new( @@ -114,14 +114,22 @@ impl VHostController { &self, hostname: Option, external: u16, - target: DynVHostTarget, + target: DynVHostTarget, ) -> Result, Error> { self.servers.mutate(|writable| { let server = if let Some(server) = writable.remove(&external) { server } else { + let bind_reqs = Watch::new(VHostBindRequirements::default()); + let listener = VHostBindListener { + ip_info: self.interfaces.watcher.subscribe(), + port: external, + bind_reqs: bind_reqs.clone_unseen(), + listeners: BTreeMap::new(), + }; VHostServer::new( - self.interfaces.watcher.bind(BindTcp, external)?, + listener, + bind_reqs, self.db.clone(), self.crypto_provider.clone(), self.acme_cache.clone(), @@ -173,6 +181,143 @@ impl VHostController { } } +/// Union of all ProxyTargets' bind requirements for a VHostServer. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct VHostBindRequirements { + pub public_gateways: BTreeSet, + pub private_ips: BTreeSet, +} + +fn compute_bind_reqs(mapping: &Mapping) -> VHostBindRequirements { + let mut reqs = VHostBindRequirements::default(); + for (_, targets) in mapping { + for (target, rc) in targets { + if rc.strong_count() > 0 { + let (pub_gw, priv_ip) = target.0.bind_requirements(); + reqs.public_gateways.extend(pub_gw); + reqs.private_ips.extend(priv_ip); + } + } + } + reqs +} + +/// Listener that manages its own TCP listeners with IP-level precision. +/// Binds ALL IPs of public gateways and ONLY matching private IPs. +pub struct VHostBindListener { + ip_info: Watch>, + port: u16, + bind_reqs: Watch, + listeners: BTreeMap, +} + +fn update_vhost_listeners( + listeners: &mut BTreeMap, + port: u16, + ip_info: &OrdMap, + reqs: &VHostBindRequirements, +) -> Result<(), Error> { + let mut keep = BTreeSet::::new(); + for (gw_id, info) in ip_info { + if let Some(ip_info) = &info.ip_info { + for ipnet in &ip_info.subnets { + let ip = ipnet.addr(); + let should_bind = + reqs.public_gateways.contains(gw_id) || reqs.private_ips.contains(&ip); + if should_bind { + let addr = match ip { + IpAddr::V6(ip6) => SocketAddrV6::new( + ip6, + port, + 0, + if ipv6_is_link_local(ip6) { + ip_info.scope_id + } else { + 0 + }, + ) + .into(), + ip => SocketAddr::new(ip, port), + }; + keep.insert(addr); + if let Some((_, existing_info)) = listeners.get_mut(&addr) { + *existing_info = GatewayInfo { + id: gw_id.clone(), + info: info.clone(), + }; + } else { + let tcp = TcpListener::from_std( + mio::net::TcpListener::bind(addr) + .with_kind(ErrorKind::Network)? + .into(), + ) + .with_kind(ErrorKind::Network)?; + listeners.insert( + addr, + ( + tcp, + GatewayInfo { + id: gw_id.clone(), + info: info.clone(), + }, + ), + ); + } + } + } + } + } + listeners.retain(|key, _| keep.contains(key)); + Ok(()) +} + +impl Accept for VHostBindListener { + type Metadata = NetworkInterfaceListenerAcceptMetadata; + fn poll_accept( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + // Update listeners when ip_info or bind_reqs change + while self.ip_info.poll_changed(cx).is_ready() + || self.bind_reqs.poll_changed(cx).is_ready() + { + let reqs = self.bind_reqs.read_and_mark_seen(); + let listeners = &mut self.listeners; + let port = self.port; + self.ip_info.peek_and_mark_seen(|ip_info| { + update_vhost_listeners(listeners, port, ip_info, &reqs) + })?; + } + + // Poll each listener for incoming connections + for (&addr, (listener, gw_info)) in &self.listeners { + match listener.poll_accept(cx) { + Poll::Ready(Ok((stream, peer_addr))) => { + if let Err(e) = socket2::SockRef::from(&stream).set_keepalive(true) { + tracing::error!("Failed to set tcp keepalive: {e}"); + tracing::debug!("{e:?}"); + } + return Poll::Ready(Ok(( + NetworkInterfaceListenerAcceptMetadata { + inner: TcpMetadata { + local_addr: addr, + peer_addr, + }, + info: gw_info.clone(), + }, + Box::pin(stream), + ))); + } + Poll::Ready(Err(e)) => { + tracing::trace!("VHostBindListener accept error on {addr}: {e}"); + } + Poll::Pending => {} + } + } + Poll::Pending + } +} + pub trait VHostTarget: std::fmt::Debug + Eq { type PreprocessRes: Send + 'static; #[allow(unused_variables)] @@ -182,6 +327,10 @@ pub trait VHostTarget: std::fmt::Debug + Eq { fn acme(&self) -> Option<&AcmeProvider> { None } + /// Returns (public_gateways, private_ips) this target needs the listener to bind on. + fn bind_requirements(&self) -> (BTreeSet, BTreeSet) { + (BTreeSet::new(), BTreeSet::new()) + } fn preprocess<'a>( &'a self, prev: ServerConfig, @@ -200,6 +349,7 @@ pub trait VHostTarget: std::fmt::Debug + Eq { pub trait DynVHostTargetT: std::fmt::Debug + Any { fn filter(&self, metadata: &::Metadata) -> bool; fn acme(&self) -> Option<&AcmeProvider>; + fn bind_requirements(&self) -> (BTreeSet, BTreeSet); fn preprocess<'a>( &'a self, prev: ServerConfig, @@ -224,6 +374,9 @@ impl + 'static> DynVHostTargetT for T { fn acme(&self) -> Option<&AcmeProvider> { VHostTarget::acme(self) } + fn bind_requirements(&self) -> (BTreeSet, BTreeSet) { + VHostTarget::bind_requirements(self) + } fn preprocess<'a>( &'a self, prev: ServerConfig, @@ -301,7 +454,8 @@ impl Preprocessed { #[derive(Clone)] pub struct ProxyTarget { - pub filter: DynInterfaceFilter, + pub public: BTreeSet, + pub private: BTreeSet, pub acme: Option, pub addr: SocketAddr, pub add_x_forwarded_headers: bool, @@ -309,7 +463,8 @@ pub struct ProxyTarget { } impl PartialEq for ProxyTarget { fn eq(&self, other: &Self) -> bool { - self.filter == other.filter + self.public == other.public + && self.private == other.private && self.acme == other.acme && self.addr == other.addr && self.connect_ssl.as_ref().map(Arc::as_ptr) @@ -320,7 +475,8 @@ impl Eq for ProxyTarget {} impl fmt::Debug for ProxyTarget { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ProxyTarget") - .field("filter", &self.filter) + .field("public", &self.public) + .field("private", &self.private) .field("acme", &self.acme) .field("addr", &self.addr) .field("add_x_forwarded_headers", &self.add_x_forwarded_headers) @@ -340,16 +496,37 @@ where { type PreprocessRes = AcceptStream; fn filter(&self, metadata: &::Metadata) -> bool { - let info = extract::(metadata); - if info.is_none() { - tracing::warn!("No GatewayInfo on metadata"); + let gw = extract::(metadata); + let tcp = extract::(metadata); + let (Some(gw), Some(tcp)) = (gw, tcp) else { + return false; + }; + let Some(ip_info) = &gw.info.ip_info else { + return false; + }; + + let src = tcp.peer_addr.ip(); + // Public if: source is a gateway/router IP (NAT'd internet), + // or source is outside all known subnets (direct internet) + let is_public = ip_info.lan_ip.contains(&src) + || !ip_info.subnets.iter().any(|s| s.contains(&src)); + + if is_public { + self.public.contains(&gw.id) + } else { + // Private: accept if connection arrived on an interface with a matching IP + ip_info + .subnets + .iter() + .any(|s| self.private.contains(&s.addr())) } - info.as_ref() - .map_or(true, |i| self.filter.filter(&i.id, &i.info)) } fn acme(&self) -> Option<&AcmeProvider> { self.acme.as_ref() } + fn bind_requirements(&self) -> (BTreeSet, BTreeSet) { + (self.public.clone(), self.private.clone()) + } async fn preprocess<'a>( &'a self, mut prev: ServerConfig, @@ -634,28 +811,15 @@ where struct VHostServer { mapping: Watch>, + bind_reqs: Watch, _thread: NonDetachingJoinHandle<()>, } -impl<'a> From<&'a BTreeMap, BTreeMap>>> for AnyFilter { - fn from(value: &'a BTreeMap, BTreeMap>>) -> Self { - Self( - value - .iter() - .flat_map(|(_, v)| { - v.iter() - .filter(|(_, r)| r.strong_count() > 0) - .map(|(t, _)| t.filter.clone()) - }) - .collect(), - ) - } -} - impl VHostServer { #[instrument(skip_all)] fn new( listener: A, + bind_reqs: Watch, db: TypedPatchDb, crypto_provider: Arc, acme_cache: AcmeTlsAlpnCache, @@ -679,6 +843,7 @@ impl VHostServer { let mapping = Watch::new(BTreeMap::new()); Self { mapping: mapping.clone(), + bind_reqs, _thread: tokio::spawn(async move { let mut listener = VHostListener(TlsListener::new( listener, @@ -729,6 +894,9 @@ impl VHostServer { targets.insert(target, Arc::downgrade(&rc)); writable.insert(hostname, targets); res = Ok(rc); + if changed { + self.update_bind_reqs(writable); + } changed }); if self.mapping.watcher_count() > 1 { @@ -752,9 +920,23 @@ impl VHostServer { if !targets.is_empty() { writable.insert(hostname, targets); } + if pre != post { + self.update_bind_reqs(writable); + } pre == post }); } + fn update_bind_reqs(&self, mapping: &Mapping) { + let new_reqs = compute_bind_reqs(mapping); + self.bind_reqs.send_if_modified(|reqs| { + if *reqs != new_reqs { + *reqs = new_reqs; + true + } else { + false + } + }); + } fn is_empty(&self) -> bool { self.mapping.peek(|m| m.is_empty()) } diff --git a/core/src/net/web_server.rs b/core/src/net/web_server.rs index 2ac5b035f..8ffe9deaa 100644 --- a/core/src/net/web_server.rs +++ b/core/src/net/web_server.rs @@ -366,28 +366,6 @@ where pub struct WebServerAcceptorSetter { acceptor: Watch, } -impl WebServerAcceptorSetter>> -where - A: Accept, - B: Accept, -{ - pub fn try_upgrade Result>(&self, f: F) -> Result<(), Error> { - let mut res = Ok(()); - self.acceptor.send_modify(|a| { - *a = match a.take() { - Some(Either::Left(a)) => match f(a) { - Ok(b) => Some(Either::Right(b)), - Err(e) => { - res = Err(e); - None - } - }, - x => x, - } - }); - res - } -} impl Deref for WebServerAcceptorSetter { type Target = Watch; fn deref(&self) -> &Self::Target { diff --git a/core/src/tunnel/api.rs b/core/src/tunnel/api.rs index b8f5fd693..47f9a33e3 100644 --- a/core/src/tunnel/api.rs +++ b/core/src/tunnel/api.rs @@ -459,7 +459,7 @@ pub async fn add_forward( }) .map(|s| s.prefix_len()) .unwrap_or(32); - let rc = ctx.forward.add_forward(source, target, prefix).await?; + let rc = ctx.forward.add_forward(source, target, prefix, None).await?; ctx.active_forwards.mutate(|m| { m.insert(source, rc); }); diff --git a/core/src/tunnel/context.rs b/core/src/tunnel/context.rs index 5afac62ab..ac56eaa36 100644 --- a/core/src/tunnel/context.rs +++ b/core/src/tunnel/context.rs @@ -199,7 +199,7 @@ impl TunnelContext { }) .map(|s| s.prefix_len()) .unwrap_or(32); - active_forwards.insert(from, forward.add_forward(from, to, prefix).await?); + active_forwards.insert(from, forward.add_forward(from, to, prefix, None).await?); } Ok(Self(Arc::new(TunnelContextSeed { diff --git a/core/src/version/v0_4_0_alpha_20.rs b/core/src/version/v0_4_0_alpha_20.rs index 3da7caf5c..625d72a51 100644 --- a/core/src/version/v0_4_0_alpha_20.rs +++ b/core/src/version/v0_4_0_alpha_20.rs @@ -89,6 +89,10 @@ impl VersionT for Version { } } + // Migrate availablePorts from IdPool format to BTreeMap + // Rebuild from actual assigned ports in all bindings + migrate_available_ports(db); + Ok(Value::Null) } fn down(self, _db: &mut Value) -> Result<(), Error> { @@ -96,6 +100,62 @@ impl VersionT for Version { } } +fn collect_ports_from_host(host: Option<&Value>, ports: &mut Value) { + let Some(bindings) = host + .and_then(|h| h.get("bindings")) + .and_then(|b| b.as_object()) + else { + return; + }; + for (_, binding) in bindings.iter() { + if let Some(net) = binding.get("net") { + if let Some(port) = net.get("assignedPort").and_then(|p| p.as_u64()) { + if let Some(obj) = ports.as_object_mut() { + obj.insert(port.to_string().into(), Value::from(false)); + } + } + if let Some(port) = net.get("assignedSslPort").and_then(|p| p.as_u64()) { + if let Some(obj) = ports.as_object_mut() { + obj.insert(port.to_string().into(), Value::from(true)); + } + } + } + } +} + +fn migrate_available_ports(db: &mut Value) { + let mut new_ports: Value = serde_json::json!({}).into(); + + // Collect from server host + let server_host = db + .get("public") + .and_then(|p| p.get("serverInfo")) + .and_then(|s| s.get("network")) + .and_then(|n| n.get("host")) + .cloned(); + collect_ports_from_host(server_host.as_ref(), &mut new_ports); + + // Collect from all package hosts + if let Some(packages) = db + .get("public") + .and_then(|p| p.get("packageData")) + .and_then(|p| p.as_object()) + { + for (_, package) in packages.iter() { + if let Some(hosts) = package.get("hosts").and_then(|h| h.as_object()) { + for (_, host) in hosts.iter() { + collect_ports_from_host(Some(host), &mut new_ports); + } + } + } + } + + // Replace private.availablePorts + if let Some(private) = db.get_mut("private").and_then(|p| p.as_object_mut()) { + private.insert("availablePorts".into(), new_ports); + } +} + fn migrate_host(host: Option<&mut Value>) { let Some(host) = host.and_then(|h| h.as_object_mut()) else { return; diff --git a/sdk/package/lib/StartSdk.ts b/sdk/package/lib/StartSdk.ts index 9d1dc0164..3855d594d 100644 --- a/sdk/package/lib/StartSdk.ts +++ b/sdk/package/lib/StartSdk.ts @@ -67,7 +67,7 @@ import { import { getOwnServiceInterfaces } from '../../base/lib/util/getServiceInterfaces' import { Volumes, createVolumes } from './util/Volume' -export const OSVersion = testTypeVersion('0.4.0-alpha.19') +export const OSVersion = testTypeVersion('0.4.0-alpha.20') // prettier-ignore type AnyNeverCond = diff --git a/web/package-lock.json b/web/package-lock.json index bb0aca4f2..8c72087c8 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "startos-ui", - "version": "0.4.0-alpha.19", + "version": "0.4.0-alpha.20", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "startos-ui", - "version": "0.4.0-alpha.19", + "version": "0.4.0-alpha.20", "license": "MIT", "dependencies": { "@angular/animations": "^20.3.0", diff --git a/web/package.json b/web/package.json index 49fc3a76d..6d4e17883 100644 --- a/web/package.json +++ b/web/package.json @@ -1,6 +1,6 @@ { "name": "startos-ui", - "version": "0.4.0-alpha.19", + "version": "0.4.0-alpha.20", "author": "Start9 Labs, Inc", "homepage": "https://start9.com/", "license": "MIT",