From 5ae9a555cec3eee14e276d5e5a7c0d7fdbfbdf20 Mon Sep 17 00:00:00 2001 From: Aiden McClelland Date: Fri, 31 Oct 2025 12:21:02 -0600 Subject: [PATCH] wip: separate port forward controller into parts --- core/startos/src/net/forward.rs | 344 ++++++++++++++++++++----- core/startos/src/net/net_controller.rs | 6 +- core/startos/src/tunnel/api.rs | 24 +- core/startos/src/tunnel/context.rs | 27 +- 4 files changed, 315 insertions(+), 86 deletions(-) diff --git a/core/startos/src/net/forward.rs b/core/startos/src/net/forward.rs index 92deaecbe..979b5d70e 100644 --- a/core/startos/src/net/forward.rs +++ b/core/startos/src/net/forward.rs @@ -75,7 +75,195 @@ pub fn forward_api() -> ParentHandler { ) } -struct ForwardRequest { +struct ForwardMapping { + source: SocketAddr, + target: SocketAddr, + interface: GatewayId, + rc: Weak<()>, +} + +#[derive(Default)] +struct PortForwardState { + mappings: BTreeMap, // source -> target +} + +impl PortForwardState { + async fn add_forward( + &mut self, + interface: GatewayId, + source: SocketAddr, + target: SocketAddr, + rc: Arc<()>, + ) -> Result, Error> { + // Check if mapping already exists + if let Some(existing) = self.mappings.get_mut(&source) { + // If it's the same target and interface, just update the reference + if existing.target == target && existing.interface == interface { + if let Some(existing_rc) = existing.rc.upgrade() { + return Ok(existing_rc); + } else { + existing.rc = Arc::downgrade(&rc); + return Ok(rc); + } + } else { + // Different target/interface, need to remove old and add new + if let Some(mapping) = self.mappings.remove(&source) { + unforward(mapping.interface.as_str(), mapping.source, mapping.target).await?; + } + } + } + + // Add the new forward + forward(interface.as_str(), source, target).await?; + self.mappings.insert( + source, + ForwardMapping { + source, + target, + interface: interface.clone(), + rc: Arc::downgrade(&rc), + }, + ); + + Ok(rc) + } + + async fn gc(&mut self) -> Result<(), Error> { + let to_remove: Vec = self + .mappings + .iter() + .filter(|(_, mapping)| mapping.rc.strong_count() == 0) + .map(|(source, _)| *source) + .collect(); + + for source in to_remove { + if let Some(mapping) = self.mappings.remove(&source) { + unforward(mapping.interface.as_str(), mapping.source, mapping.target).await?; + } + } + Ok(()) + } + + fn dump(&self) -> BTreeMap { + self.mappings + .iter() + .filter(|(_, mapping)| mapping.rc.strong_count() > 0) + .map(|(source, mapping)| (*source, mapping.target)) + .collect() + } +} + +impl Drop for PortForwardState { + fn drop(&mut self) { + if !self.mappings.is_empty() { + let mappings = std::mem::take(&mut self.mappings); + tokio::spawn(async move { + for (_, mapping) in mappings { + unforward(mapping.interface.as_str(), mapping.source, mapping.target) + .await + .log_err(); + } + }); + } + } +} + +enum PortForwardCommand { + AddForward { + interface: GatewayId, + source: SocketAddr, + target: SocketAddr, + rc: Arc<()>, + respond: oneshot::Sender, Error>>, + }, + Gc { + respond: oneshot::Sender>, + }, + Dump { + respond: oneshot::Sender>, + }, +} + +pub struct PortForwardController { + req: mpsc::UnboundedSender, + _thread: NonDetachingJoinHandle<()>, +} + +impl PortForwardController { + pub fn new() -> Self { + let (req_send, mut req_recv) = mpsc::unbounded_channel::(); + let thread = NonDetachingJoinHandle::from(tokio::spawn(async move { + let mut state = PortForwardState::default(); + while let Some(cmd) = req_recv.recv().await { + match cmd { + PortForwardCommand::AddForward { + interface, + source, + target, + rc, + respond, + } => { + let result = state.add_forward(interface, source, target, rc).await; + respond.send(result).ok(); + } + PortForwardCommand::Gc { respond } => { + let result = state.gc().await; + respond.send(result).ok(); + } + PortForwardCommand::Dump { respond } => { + respond.send(state.dump()).ok(); + } + } + } + })); + + Self { + req: req_send, + _thread: thread, + } + } + + pub async fn add_forward( + &self, + interface: GatewayId, + source: SocketAddr, + target: SocketAddr, + ) -> Result, Error> { + let rc = Arc::new(()); + let (send, recv) = oneshot::channel(); + self.req + .send(PortForwardCommand::AddForward { + interface, + source, + target, + rc, + respond: send, + }) + .map_err(err_has_exited)?; + + recv.await.map_err(err_has_exited)? + } + + pub async fn gc(&self) -> Result<(), Error> { + let (send, recv) = oneshot::channel(); + self.req + .send(PortForwardCommand::Gc { respond: send }) + .map_err(err_has_exited)?; + + recv.await.map_err(err_has_exited)? + } + + pub async fn dump(&self) -> Result, Error> { + let (send, recv) = oneshot::channel(); + self.req + .send(PortForwardCommand::Dump { respond: send }) + .map_err(err_has_exited)?; + + recv.await.map_err(err_has_exited) + } +} + +struct InterfaceForwardRequest { external: u16, target: SocketAddr, filter: DynInterfaceFilter, @@ -83,12 +271,14 @@ struct ForwardRequest { } #[derive(Clone)] -struct ForwardEntry { +struct InterfaceForwardEntry { external: u16, filter: BTreeMap)>, - forwards: BTreeMap, + // Maps source SocketAddr -> strong reference for the forward created in PortForwardController + forwards: BTreeMap>, } -impl IdOrdItem for ForwardEntry { + +impl IdOrdItem for InterfaceForwardEntry { type Key<'a> = u16; fn key(&self) -> Self::Key<'_> { self.external @@ -96,7 +286,8 @@ impl IdOrdItem for ForwardEntry { iddqd::id_upcast!(); } -impl ForwardEntry { + +impl InterfaceForwardEntry { fn new(external: u16) -> Self { Self { external, @@ -105,26 +296,13 @@ impl ForwardEntry { } } - fn take(&mut self) -> Self { - Self { - external: self.external, - filter: std::mem::take(&mut self.filter), - forwards: std::mem::take(&mut self.forwards), - } - } - - async fn destroy(mut self) -> Result<(), Error> { - while let Some((source, (interface, target))) = self.forwards.pop_first() { - unforward(interface.as_str(), source, target).await?; - } - Ok(()) - } - async fn update( &mut self, ip_info: &OrdMap, + port_forward: &PortForwardController, ) -> Result<(), Error> { let mut keep = BTreeSet::::new(); + for (iface, info) in ip_info.iter() { if let Some(target) = self .filter @@ -151,40 +329,36 @@ impl ForwardEntry { }; keep.insert(addr); if !self.forwards.contains_key(&addr) { - forward(iface.as_str(), addr, target).await?; - self.forwards.insert(addr, (iface.clone(), target)); + let rc = port_forward + .add_forward(iface.clone(), addr, target) + .await?; + self.forwards.insert(addr, rc); } } } } } - let rm = self - .forwards - .keys() - .copied() - .filter(|a| !keep.contains(a)) - .collect::>(); - for rm in rm { - if let Some((source, (interface, target))) = self.forwards.remove_entry(&rm) { - unforward(interface.as_str(), source, target).await?; - } - } + + // Remove forwards that should no longer exist (drops the strong references) + self.forwards.retain(|addr, _| keep.contains(addr)); + Ok(()) } async fn update_request( &mut self, - ForwardRequest { + InterfaceForwardRequest { external, target, filter, mut rc, - }: ForwardRequest, + }: InterfaceForwardRequest, ip_info: &OrdMap, + port_forward: &PortForwardController, ) -> Result, Error> { if external != self.external { return Err(Error::new( - eyre!("Mismatched external port in ForwardEntry"), + eyre!("Mismatched external port in InterfaceForwardEntry"), ErrorKind::InvalidRequest, )); } @@ -203,44 +377,57 @@ impl ForwardEntry { entry.1 = Arc::downgrade(&rc); } - self.update(ip_info).await?; + self.update(ip_info, port_forward).await?; Ok(rc) } + + async fn gc( + &mut self, + ip_info: &OrdMap, + port_forward: &PortForwardController, + ) -> Result<(), Error> { + // Clean up dead filter references + self.filter.retain(|_, (_, rc)| rc.strong_count() > 0); + + // Update to add/remove forwards based on current state (this will drop strong references as needed) + self.update(ip_info, port_forward).await + } } -impl Drop for ForwardEntry { - fn drop(&mut self) { - if !self.forwards.is_empty() { - let take = self.take(); - tokio::spawn(async move { - take.destroy().await.log_err(); - }); + +struct InterfaceForwardState { + port_forward: PortForwardController, + state: IdOrdMap, +} + +impl InterfaceForwardState { + fn new(port_forward: PortForwardController) -> Self { + Self { + port_forward, + state: IdOrdMap::new(), } } } -#[derive(Default, Clone)] -struct ForwardState { - state: IdOrdMap, -} -impl ForwardState { +impl InterfaceForwardState { async fn handle_request( &mut self, - request: ForwardRequest, + request: InterfaceForwardRequest, ip_info: &OrdMap, ) -> Result, Error> { self.state .entry(request.external) - .or_insert_with(|| ForwardEntry::new(request.external)) - .update_request(request, ip_info) + .or_insert_with(|| InterfaceForwardEntry::new(request.external)) + .update_request(request, ip_info, &self.port_forward) .await } + async fn sync( &mut self, ip_info: &OrdMap, ) -> Result<(), Error> { for mut entry in self.state.iter_mut() { - entry.update(ip_info).await?; + entry.gc(ip_info, &self.port_forward).await?; } Ok(()) } @@ -255,13 +442,15 @@ fn err_has_exited(_: T) -> Error { #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ForwardTable(pub BTreeMap); + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ForwardTarget { pub target: SocketAddr, pub filter: String, } -impl From<&ForwardState> for ForwardTable { - fn from(value: &ForwardState) -> Self { + +impl From<&InterfaceForwardState> for ForwardTable { + fn from(value: &InterfaceForwardState) -> Self { Self( value .state @@ -286,8 +475,11 @@ impl From<&ForwardState> for ForwardTable { } } -enum ForwardCommand { - Forward(ForwardRequest, oneshot::Sender, Error>>), +enum InterfaceForwardCommand { + Forward( + InterfaceForwardRequest, + oneshot::Sender, Error>>, + ), Sync(oneshot::Sender>), DumpTable(oneshot::Sender), } @@ -302,24 +494,33 @@ fn test() { ); } -pub struct PortForwardController { - req: mpsc::UnboundedSender, +pub struct InterfacePortForwardController { + req: mpsc::UnboundedSender, _thread: NonDetachingJoinHandle<()>, } -impl PortForwardController { + +impl InterfacePortForwardController { pub fn new(mut ip_info: Watch>) -> Self { - let (req_send, mut req_recv) = mpsc::unbounded_channel::(); + let port_forward = PortForwardController::new(); + + let (req_send, mut req_recv) = mpsc::unbounded_channel::(); let thread = NonDetachingJoinHandle::from(tokio::spawn(async move { - let mut state = ForwardState::default(); + let mut state = InterfaceForwardState::new(port_forward); let mut interfaces = ip_info.read_and_mark_seen(); loop { tokio::select! { msg = req_recv.recv() => { if let Some(cmd) = msg { match cmd { - ForwardCommand::Forward(req, re) => re.send(state.handle_request(req, &interfaces).await).ok(), - ForwardCommand::Sync(re) => re.send(state.sync(&interfaces).await).ok(), - ForwardCommand::DumpTable(re) => re.send((&state).into()).ok(), + InterfaceForwardCommand::Forward(req, re) => { + re.send(state.handle_request(req, &interfaces).await).ok() + } + InterfaceForwardCommand::Sync(re) => { + re.send(state.sync(&interfaces).await).ok() + } + InterfaceForwardCommand::DumpTable(re) => { + re.send((&state).into()).ok() + } }; } else { break; @@ -328,15 +529,18 @@ impl PortForwardController { _ = ip_info.changed() => { interfaces = ip_info.read(); state.sync(&interfaces).await.log_err(); + state.port_forward.gc().await.log_err(); } } } })); + Self { req: req_send, _thread: thread, } } + pub async fn add( &self, external: u16, @@ -346,8 +550,8 @@ impl PortForwardController { let rc = Arc::new(()); let (send, recv) = oneshot::channel(); self.req - .send(ForwardCommand::Forward( - ForwardRequest { + .send(InterfaceForwardCommand::Forward( + InterfaceForwardRequest { external, target, filter, @@ -359,18 +563,20 @@ impl PortForwardController { recv.await.map_err(err_has_exited)? } + pub async fn gc(&self) -> Result<(), Error> { let (send, recv) = oneshot::channel(); self.req - .send(ForwardCommand::Sync(send)) + .send(InterfaceForwardCommand::Sync(send)) .map_err(err_has_exited)?; recv.await.map_err(err_has_exited)? } + pub async fn dump_table(&self) -> Result { let (req, res) = oneshot::channel(); self.req - .send(ForwardCommand::DumpTable(req)) + .send(InterfaceForwardCommand::DumpTable(req)) .map_err(err_has_exited)?; res.await.map_err(err_has_exited) } diff --git a/core/startos/src/net/net_controller.rs b/core/startos/src/net/net_controller.rs index 7f1497234..4b334c2ba 100644 --- a/core/startos/src/net/net_controller.rs +++ b/core/startos/src/net/net_controller.rs @@ -18,7 +18,7 @@ use crate::db::model::public::NetworkInterfaceType; use crate::error::ErrorCollection; use crate::hostname::Hostname; use crate::net::dns::DnsController; -use crate::net::forward::{PortForwardController, START9_BRIDGE_IFACE}; +use crate::net::forward::{InterfacePortForwardController, START9_BRIDGE_IFACE}; use crate::net::gateway::{ AndFilter, DynInterfaceFilter, IdFilter, InterfaceFilter, NetworkInterfaceController, OrFilter, PublicFilter, SecureFilter, TypeFilter, @@ -42,7 +42,7 @@ pub struct NetController { pub(super) tls_client_config: Arc, pub(crate) net_iface: Arc, pub(super) dns: DnsController, - pub(super) forward: PortForwardController, + pub(super) forward: InterfacePortForwardController, pub(super) socks: SocksController, pub(super) server_hostnames: Vec>, pub(crate) callbacks: Arc, @@ -76,7 +76,7 @@ impl NetController { vhost: VHostController::new(db.clone(), net_iface.clone(), crypto_provider), tls_client_config, dns: DnsController::init(db, &net_iface.watcher).await?, - forward: PortForwardController::new(net_iface.watcher.subscribe()), + forward: InterfacePortForwardController::new(net_iface.watcher.subscribe()), net_iface, socks, server_hostnames: vec![ diff --git a/core/startos/src/tunnel/api.rs b/core/startos/src/tunnel/api.rs index 238143fc0..f23b66fe5 100644 --- a/core/startos/src/tunnel/api.rs +++ b/core/startos/src/tunnel/api.rs @@ -7,7 +7,6 @@ use rpc_toolkit::{Context, Empty, HandlerArgs, HandlerExt, ParentHandler, from_f use serde::{Deserialize, Serialize}; use crate::context::CliContext; -use crate::net::gateway::{IdFilter, InterfaceFilter}; use crate::prelude::*; use crate::tunnel::context::TunnelContext; use crate::tunnel::db::GatewayPort; @@ -372,13 +371,26 @@ pub async fn add_forward( .mutate(|db| db.as_port_forwards_mut().insert(&source, &target)) .await .result?; + + // source is (GatewayId, port), target is SocketAddrV4 + // Find the first IP address for the specified gateway and create a source SocketAddr + let source_addr = ctx.net_iface.peek(|ifaces| { + ifaces + .get(&source.0) + .and_then(|info| info.ip_info.as_ref()) + .and_then(|ip_info| ip_info.subnets.iter().next()) + .map(|ipnet| std::net::SocketAddr::new(ipnet.addr(), source.1)) + }) + .ok_or_else(|| { + Error::new( + eyre!("Gateway {} not found or has no IP addresses", source.0), + crate::ErrorKind::Network, + ) + })?; + let rc = ctx .forward - .add( - source.1, - IdFilter(source.0.clone()).into_dyn(), - target.into(), - ) + .add_forward(source.0.clone(), source_addr, target.into()) .await?; ctx.active_forwards.mutate(|m| { m.insert(source, rc); diff --git a/core/startos/src/tunnel/context.rs b/core/startos/src/tunnel/context.rs index 51b32774e..4a4cfaf86 100644 --- a/core/startos/src/tunnel/context.rs +++ b/core/startos/src/tunnel/context.rs @@ -28,7 +28,6 @@ use crate::else_empty_dir; use crate::middleware::auth::{Auth, AuthContext}; use crate::middleware::cors::Cors; use crate::net::forward::PortForwardController; -use crate::net::gateway::{IdFilter, InterfaceFilter}; use crate::net::static_server::UiContext; use crate::prelude::*; use crate::rpc_continuations::{OpenAuthedContinuations, RpcContinuations}; @@ -130,7 +129,7 @@ impl TunnelContext { .await .result?; let net_iface = Watch::new(net_iface); - let forward = PortForwardController::new(net_iface.clone_unseen()); + let forward = PortForwardController::new(); Command::new("sysctl") .arg("-w") @@ -183,12 +182,24 @@ impl TunnelContext { peek.as_wg().de()?.sync().await?; let mut active_forwards = BTreeMap::new(); for (from, to) in peek.as_port_forwards().de()?.0 { - active_forwards.insert( - from.clone(), - forward - .add(from.1, IdFilter(from.0).into_dyn(), to.into()) - .await?, - ); + // from is (GatewayId, u16), to is SocketAddr + // Create the source SocketAddr for each interface matching the gateway + for (gateway_id, info) in net_iface.peek(|i| i.clone()) { + if gateway_id == from.0 { + if let Some(ip_info) = &info.ip_info { + if let Some(ipnet) = ip_info.subnets.iter().next() { + let source = std::net::SocketAddr::new(ipnet.addr(), from.1); + active_forwards.insert( + from.clone(), + forward + .add_forward(gateway_id.clone(), source, to.into()) + .await?, + ); + } + } + break; + } + } } Ok(Self(Arc::new(TunnelContextSeed {