diff --git a/build/lib/scripts/forward-port b/build/lib/scripts/forward-port index 7b8aedac9..8152c6b7a 100755 --- a/build/lib/scripts/forward-port +++ b/build/lib/scripts/forward-port @@ -1,26 +1,38 @@ #!/bin/bash -if [ -z "$iiface" ] || [ -z "$oiface" ] || [ -z "$sip" ] || [ -z "$dip" ] || [ -z "$sport" ] || [ -z "$dport" ]; then +if [ -z "$sip" ] || [ -z "$dip" ] || [ -z "$sport" ] || [ -z "$dport" ]; then >&2 echo 'missing required env var' exit 1 fi -kind="-A" +# Helper function to check if a rule exists +nat_rule_exists() { + iptables -t nat -C "$@" 2>/dev/null +} + +# Helper function to add or delete a rule idempotently +# Usage: apply_rule [add|del] +apply_nat_rule() { + local action="$1" + shift + + if [ "$action" = "add" ]; then + # Only add if rule doesn't exist + if ! rule_exists "$@"; then + iptables -t nat -A "$@" + fi + elif [ "$action" = "del" ]; then + if rule_exists "$@"; then + iptables -t nat -D "$@" + fi + fi +} if [ "$UNDO" = 1 ]; then - kind="-D" + action="del" +else + action="add" fi -iptables -t nat "$kind" POSTROUTING -o $iiface -j MASQUERADE -iptables -t nat "$kind" PREROUTING -i $iiface -p tcp --dport $sport -j DNAT --to-destination $dip:$dport -iptables -t nat "$kind" PREROUTING -i $iiface -p udp --dport $sport -j DNAT --to-destination $dip:$dport -iptables -t nat "$kind" PREROUTING -i $oiface -s $dip/24 -d $sip -p tcp --dport $sport -j DNAT --to-destination $dip:$dport -iptables -t nat "$kind" PREROUTING -i $oiface -s $dip/24 -d $sip -p udp --dport $sport -j DNAT --to-destination $dip:$dport -iptables -t nat "$kind" POSTROUTING -o $oiface -s $dip/24 -d $dip/32 -p tcp --dport $dport -j SNAT --to-source $sip:$sport -iptables -t nat "$kind" POSTROUTING -o $oiface -s $dip/24 -d $dip/32 -p udp --dport $dport -j SNAT --to-source $sip:$sport - - -iptables -t nat "$kind" PREROUTING -i $iiface -s $sip/32 -d $sip -p tcp --dport $sport -j DNAT --to-destination $dip:$dport -iptables -t nat "$kind" PREROUTING -i $iiface -s $sip/32 -d $sip -p udp --dport $sport -j DNAT --to-destination $dip:$dport -iptables -t nat "$kind" POSTROUTING -o $oiface -s $sip/32 -d $dip/32 -p tcp --dport $dport -j SNAT --to-source $sip:$sport -iptables -t nat "$kind" POSTROUTING -o $oiface -s $sip/32 -d $dip/32 -p udp --dport $dport -j SNAT --to-source $sip:$sport \ No newline at end of file +apply_nat_rule "$action" PREROUTING -p tcp -d $sip --dport $sport -j DNAT --to-destination $dip:$dport +apply_nat_rule "$action" OUTPUT -p tcp -d $sip --dport $sport -j DNAT --to-destination $dip:$dport \ No newline at end of file diff --git a/core/startos/src/net/forward.rs b/core/startos/src/net/forward.rs index 979b5d70e..d02c8d1af 100644 --- a/core/startos/src/net/forward.rs +++ b/core/startos/src/net/forward.rs @@ -1,6 +1,7 @@ use std::collections::{BTreeMap, BTreeSet}; -use std::net::{IpAddr, SocketAddr, SocketAddrV6}; +use std::net::{IpAddr, SocketAddrV4}; use std::sync::{Arc, Weak}; +use std::time::Duration; use futures::channel::oneshot; use helpers::NonDetachingJoinHandle; @@ -16,7 +17,6 @@ use tokio::sync::mpsc; use crate::context::{CliContext, RpcContext}; use crate::db::model::public::NetworkInterfaceInfo; use crate::net::gateway::{DynInterfaceFilter, InterfaceFilter}; -use crate::net::utils::ipv6_is_link_local; use crate::prelude::*; use crate::util::Invoke; use crate::util::serde::{HandlerExtSerde, display_serializable}; @@ -76,51 +76,46 @@ pub fn forward_api() -> ParentHandler { } struct ForwardMapping { - source: SocketAddr, - target: SocketAddr, - interface: GatewayId, + source: SocketAddrV4, + target: SocketAddrV4, rc: Weak<()>, } #[derive(Default)] struct PortForwardState { - mappings: BTreeMap, // source -> target + mappings: BTreeMap, // source -> target } impl PortForwardState { async fn add_forward( &mut self, - interface: GatewayId, - source: SocketAddr, - target: SocketAddr, - rc: Arc<()>, + source: SocketAddrV4, + target: SocketAddrV4, ) -> Result, Error> { - // Check if mapping already exists if let Some(existing) = self.mappings.get_mut(&source) { - // If it's the same target and interface, just update the reference - if existing.target == target && existing.interface == interface { + if existing.target == target { if let Some(existing_rc) = existing.rc.upgrade() { return Ok(existing_rc); } else { + let rc = Arc::new(()); existing.rc = Arc::downgrade(&rc); return Ok(rc); } } else { - // Different target/interface, need to remove old and add new + // Different target, need to remove old and add new if let Some(mapping) = self.mappings.remove(&source) { - unforward(mapping.interface.as_str(), mapping.source, mapping.target).await?; + unforward(mapping.source, mapping.target).await?; } } } - // Add the new forward - forward(interface.as_str(), source, target).await?; + let rc = Arc::new(()); + forward(source, target).await?; self.mappings.insert( source, ForwardMapping { source, target, - interface: interface.clone(), rc: Arc::downgrade(&rc), }, ); @@ -129,7 +124,7 @@ impl PortForwardState { } async fn gc(&mut self) -> Result<(), Error> { - let to_remove: Vec = self + let to_remove: Vec = self .mappings .iter() .filter(|(_, mapping)| mapping.rc.strong_count() == 0) @@ -138,13 +133,13 @@ impl PortForwardState { for source in to_remove { if let Some(mapping) = self.mappings.remove(&source) { - unforward(mapping.interface.as_str(), mapping.source, mapping.target).await?; + unforward(mapping.source, mapping.target).await?; } } Ok(()) } - fn dump(&self) -> BTreeMap { + fn dump(&self) -> BTreeMap { self.mappings .iter() .filter(|(_, mapping)| mapping.rc.strong_count() > 0) @@ -159,9 +154,7 @@ impl Drop for PortForwardState { let mappings = std::mem::take(&mut self.mappings); tokio::spawn(async move { for (_, mapping) in mappings { - unforward(mapping.interface.as_str(), mapping.source, mapping.target) - .await - .log_err(); + unforward(mapping.source, mapping.target).await.log_err(); } }); } @@ -170,17 +163,15 @@ impl Drop for PortForwardState { enum PortForwardCommand { AddForward { - interface: GatewayId, - source: SocketAddr, - target: SocketAddr, - rc: Arc<()>, + source: SocketAddrV4, + target: SocketAddrV4, respond: oneshot::Sender, Error>>, }, Gc { respond: oneshot::Sender>, }, Dump { - respond: oneshot::Sender>, + respond: oneshot::Sender>, }, } @@ -193,17 +184,50 @@ impl PortForwardController { pub fn new() -> Self { let (req_send, mut req_recv) = mpsc::unbounded_channel::(); let thread = NonDetachingJoinHandle::from(tokio::spawn(async move { + while let Err(e) = async { + Command::new("sysctl") + .arg("-w") + .arg("net.ipv4.ip_forward=1") + .invoke(ErrorKind::Network) + .await?; + if Command::new("iptables") + .arg("-t") + .arg("nat") + .arg("-C") + .arg("POSTROUTING") + .arg("-j") + .arg("MASQUERADE") + .invoke(ErrorKind::Network) + .await + .is_err() + { + Command::new("iptables") + .arg("-t") + .arg("nat") + .arg("-A") + .arg("POSTROUTING") + .arg("-j") + .arg("MASQUERADE") + .invoke(ErrorKind::Network) + .await?; + } + Ok::<_, Error>(()) + } + .await + { + tracing::error!("error initializing PortForwardController: {e:#}"); + tracing::debug!("{e:?}"); + tokio::time::sleep(Duration::from_secs(5)).await; + } let mut state = PortForwardState::default(); while let Some(cmd) = req_recv.recv().await { match cmd { PortForwardCommand::AddForward { - interface, source, target, - rc, respond, } => { - let result = state.add_forward(interface, source, target, rc).await; + let result = state.add_forward(source, target).await; respond.send(result).ok(); } PortForwardCommand::Gc { respond } => { @@ -225,18 +249,14 @@ impl PortForwardController { pub async fn add_forward( &self, - interface: GatewayId, - source: SocketAddr, - target: SocketAddr, + source: SocketAddrV4, + target: SocketAddrV4, ) -> Result, Error> { - let rc = Arc::new(()); let (send, recv) = oneshot::channel(); self.req .send(PortForwardCommand::AddForward { - interface, source, target, - rc, respond: send, }) .map_err(err_has_exited)?; @@ -253,7 +273,7 @@ impl PortForwardController { recv.await.map_err(err_has_exited)? } - pub async fn dump(&self) -> Result, Error> { + pub async fn dump(&self) -> Result, Error> { let (send, recv) = oneshot::channel(); self.req .send(PortForwardCommand::Dump { respond: send }) @@ -265,7 +285,7 @@ impl PortForwardController { struct InterfaceForwardRequest { external: u16, - target: SocketAddr, + target: SocketAddrV4, filter: DynInterfaceFilter, rc: Arc<()>, } @@ -273,9 +293,9 @@ struct InterfaceForwardRequest { #[derive(Clone)] struct InterfaceForwardEntry { external: u16, - filter: BTreeMap)>, + filter: BTreeMap)>, // Maps source SocketAddr -> strong reference for the forward created in PortForwardController - forwards: BTreeMap>, + forwards: BTreeMap>, } impl IdOrdItem for InterfaceForwardEntry { @@ -301,7 +321,7 @@ impl InterfaceForwardEntry { ip_info: &OrdMap, port_forward: &PortForwardController, ) -> Result<(), Error> { - let mut keep = BTreeSet::::new(); + let mut keep = BTreeSet::::new(); for (iface, info) in ip_info.iter() { if let Some(target) = self @@ -312,26 +332,16 @@ impl InterfaceForwardEntry { .map(|(_, (target, _))| *target) { if let Some(ip_info) = &info.ip_info { - for ipnet in &ip_info.subnets { - let addr = match ipnet.addr() { - IpAddr::V6(ip6) => SocketAddrV6::new( - ip6, - self.external, - 0, - if ipv6_is_link_local(ip6) { - ip_info.scope_id - } else { - 0 - }, - ) - .into(), - ip => SocketAddr::new(ip, self.external), - }; + for addr in ip_info.subnets.iter().filter_map(|net| { + if let IpAddr::V4(ip) = net.addr() { + Some(SocketAddrV4::new(ip, self.external)) + } else { + None + } + }) { keep.insert(addr); if !self.forwards.contains_key(&addr) { - let rc = port_forward - .add_forward(iface.clone(), addr, target) - .await?; + let rc = port_forward.add_forward(addr, target).await?; self.forwards.insert(addr, rc); } } @@ -387,10 +397,8 @@ impl InterfaceForwardEntry { ip_info: &OrdMap, port_forward: &PortForwardController, ) -> Result<(), Error> { - // Clean up dead filter references self.filter.retain(|_, (_, rc)| rc.strong_count() > 0); - // Update to add/remove forwards based on current state (this will drop strong references as needed) self.update(ip_info, port_forward).await } } @@ -445,7 +453,7 @@ pub struct ForwardTable(pub BTreeMap); #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ForwardTarget { - pub target: SocketAddr, + pub target: SocketAddrV4, pub filter: String, } @@ -545,7 +553,7 @@ impl InterfacePortForwardController { &self, external: u16, filter: DynInterfaceFilter, - target: SocketAddr, + target: SocketAddrV4, ) -> Result, Error> { let rc = Arc::new(()); let (send, recv) = oneshot::channel(); @@ -582,16 +590,8 @@ impl InterfacePortForwardController { } } -async fn forward(interface: &str, source: SocketAddr, target: SocketAddr) -> Result<(), Error> { - if interface == START9_BRIDGE_IFACE { - return Ok(()); - } - if source.is_ipv6() { - return Ok(()); // TODO: socat? ip6tables? - } +async fn forward(source: SocketAddrV4, target: SocketAddrV4) -> Result<(), Error> { Command::new("/usr/lib/startos/scripts/forward-port") - .env("iiface", interface) - .env("oiface", START9_BRIDGE_IFACE) .env("sip", source.ip().to_string()) .env("dip", target.ip().to_string()) .env("sport", source.port().to_string()) @@ -601,17 +601,9 @@ async fn forward(interface: &str, source: SocketAddr, target: SocketAddr) -> Res Ok(()) } -async fn unforward(interface: &str, source: SocketAddr, target: SocketAddr) -> Result<(), Error> { - if interface == START9_BRIDGE_IFACE { - return Ok(()); - } - if source.is_ipv6() { - return Ok(()); // TODO: socat? ip6tables? - } +async fn unforward(source: SocketAddrV4, target: SocketAddrV4) -> Result<(), Error> { Command::new("/usr/lib/startos/scripts/forward-port") .env("UNDO", "1") - .env("iiface", interface) - .env("oiface", START9_BRIDGE_IFACE) .env("sip", source.ip().to_string()) .env("dip", target.ip().to_string()) .env("sport", source.port().to_string()) diff --git a/core/startos/src/net/net_controller.rs b/core/startos/src/net/net_controller.rs index 4b334c2ba..db9d3a175 100644 --- a/core/startos/src/net/net_controller.rs +++ b/core/startos/src/net/net_controller.rs @@ -1,5 +1,5 @@ use std::collections::{BTreeMap, BTreeSet}; -use std::net::{Ipv4Addr, SocketAddr}; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::sync::{Arc, Weak}; use color_eyre::eyre::eyre; @@ -149,7 +149,7 @@ impl NetController { #[derive(Default, Debug)] struct HostBinds { - forwards: BTreeMap)>, + forwards: BTreeMap)>, vhosts: BTreeMap<(Option, u16), (ProxyTarget, Arc<()>)>, private_dns: BTreeMap>, tor: BTreeMap, Vec>)>, @@ -241,7 +241,7 @@ impl NetServiceData { } async fn update(&mut self, ctrl: &NetController, id: HostId, host: Host) -> Result<(), Error> { - let mut forwards: BTreeMap = BTreeMap::new(); + let mut forwards: BTreeMap = BTreeMap::new(); let mut vhosts: BTreeMap<(Option, u16), ProxyTarget> = BTreeMap::new(); let mut private_dns: BTreeSet = BTreeSet::new(); let mut tor: BTreeMap)> = @@ -442,7 +442,7 @@ impl NetServiceData { forwards.insert( external, ( - (self.ip, *port).into(), + SocketAddrV4::new(self.ip, *port), AndFilter( SecureFilter { secure: bind.options.secure.is_some(), diff --git a/core/startos/src/tunnel/api.rs b/core/startos/src/tunnel/api.rs index f23b66fe5..3add85a8b 100644 --- a/core/startos/src/tunnel/api.rs +++ b/core/startos/src/tunnel/api.rs @@ -9,7 +9,6 @@ use serde::{Deserialize, Serialize}; use crate::context::CliContext; use crate::prelude::*; use crate::tunnel::context::TunnelContext; -use crate::tunnel::db::GatewayPort; use crate::tunnel::wg::{WgConfig, WgSubnetClients, WgSubnetConfig}; use crate::util::serde::{HandlerExtSerde, display_serializable}; @@ -359,7 +358,7 @@ pub async fn show_config( #[derive(Deserialize, Serialize, Parser)] #[serde(rename_all = "camelCase")] pub struct AddPortForwardParams { - source: GatewayPort, + source: SocketAddrV4, target: SocketAddrV4, } @@ -372,26 +371,7 @@ pub async fn add_forward( .await .result?; - // source is (GatewayId, port), target is SocketAddrV4 - // Find the first IP address for the specified gateway and create a source SocketAddr - let source_addr = ctx.net_iface.peek(|ifaces| { - ifaces - .get(&source.0) - .and_then(|info| info.ip_info.as_ref()) - .and_then(|ip_info| ip_info.subnets.iter().next()) - .map(|ipnet| std::net::SocketAddr::new(ipnet.addr(), source.1)) - }) - .ok_or_else(|| { - Error::new( - eyre!("Gateway {} not found or has no IP addresses", source.0), - crate::ErrorKind::Network, - ) - })?; - - let rc = ctx - .forward - .add_forward(source.0.clone(), source_addr, target.into()) - .await?; + let rc = ctx.forward.add_forward(source, target).await?; ctx.active_forwards.mutate(|m| { m.insert(source, rc); }); @@ -401,7 +381,7 @@ pub async fn add_forward( #[derive(Deserialize, Serialize, Parser)] #[serde(rename_all = "camelCase")] pub struct RemovePortForwardParams { - source: GatewayPort, + source: SocketAddrV4, } pub async fn remove_forward( diff --git a/core/startos/src/tunnel/context.rs b/core/startos/src/tunnel/context.rs index 4a4cfaf86..f646c343b 100644 --- a/core/startos/src/tunnel/context.rs +++ b/core/startos/src/tunnel/context.rs @@ -1,5 +1,5 @@ use std::collections::BTreeMap; -use std::net::{IpAddr, SocketAddr}; +use std::net::{IpAddr, SocketAddr, SocketAddrV4}; use std::ops::Deref; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -33,7 +33,7 @@ use crate::prelude::*; use crate::rpc_continuations::{OpenAuthedContinuations, RpcContinuations}; use crate::tunnel::TUNNEL_DEFAULT_LISTEN; use crate::tunnel::api::tunnel_api; -use crate::tunnel::db::{GatewayPort, TunnelDatabase}; +use crate::tunnel::db::TunnelDatabase; use crate::tunnel::wg::WIREGUARD_INTERFACE_NAME; use crate::util::Invoke; use crate::util::collections::OrdMapIterMut; @@ -85,7 +85,7 @@ pub struct TunnelContextSeed { pub ephemeral_sessions: SyncMutex, pub net_iface: Watch>, pub forward: PortForwardController, - pub active_forwards: SyncMutex>>, + pub active_forwards: SyncMutex>>, pub shutdown: Sender<()>, } @@ -182,24 +182,7 @@ impl TunnelContext { peek.as_wg().de()?.sync().await?; let mut active_forwards = BTreeMap::new(); for (from, to) in peek.as_port_forwards().de()?.0 { - // from is (GatewayId, u16), to is SocketAddr - // Create the source SocketAddr for each interface matching the gateway - for (gateway_id, info) in net_iface.peek(|i| i.clone()) { - if gateway_id == from.0 { - if let Some(ip_info) = &info.ip_info { - if let Some(ipnet) = ip_info.subnets.iter().next() { - let source = std::net::SocketAddr::new(ipnet.addr(), from.1); - active_forwards.insert( - from.clone(), - forward - .add_forward(gateway_id.clone(), source, to.into()) - .await?, - ); - } - } - break; - } - } + active_forwards.insert(from, forward.add_forward(from, to).await?); } Ok(Self(Arc::new(TunnelContextSeed { diff --git a/core/startos/src/tunnel/db.rs b/core/startos/src/tunnel/db.rs index d11a3501d..216f71007 100644 --- a/core/startos/src/tunnel/db.rs +++ b/core/startos/src/tunnel/db.rs @@ -31,51 +31,6 @@ use crate::tunnel::wg::WgServer; use crate::util::net::WebSocketExt; use crate::util::serde::{HandlerExtSerde, apply_expr, deserialize_from_str, serialize_display}; -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct GatewayPort(pub GatewayId, pub u16); -impl std::fmt::Display for GatewayPort { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}:{}", self.0, self.1) - } -} -impl std::str::FromStr for GatewayPort { - type Err = crate::Error; - fn from_str(s: &str) -> Result { - let mut parts = s.splitn(2, ':'); - let gw: GatewayId = parts - .next() - .ok_or_else(|| Error::new(eyre!("missing gateway id"), ErrorKind::ParseNetAddress))? - .parse()?; - let port: u16 = parts - .next() - .ok_or_else(|| Error::new(eyre!("missing port"), ErrorKind::ParseNetAddress))? - .parse()?; - Ok(GatewayPort(gw, port)) - } -} -impl Serialize for GatewayPort { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serialize_display(self, serializer) - } -} -impl<'de> Deserialize<'de> for GatewayPort { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserialize_from_str(deserializer) - } -} -impl ValueParserFactory for GatewayPort { - type Parser = FromStrParser; - fn value_parser() -> Self::Parser { - FromStrParser::new() - } -} - #[derive(Default, Deserialize, Serialize, HasModel)] #[serde(rename_all = "camelCase")] #[model = "Model"] @@ -90,9 +45,9 @@ pub struct TunnelDatabase { } #[derive(Clone, Debug, Default, Deserialize, Serialize)] -pub struct PortForwards(pub BTreeMap); +pub struct PortForwards(pub BTreeMap); impl Map for PortForwards { - type Key = GatewayPort; + type Key = SocketAddrV4; type Value = SocketAddrV4; fn key_str(key: &Self::Key) -> Result, Error> { Self::key_string(key)