simplify iptables rules

This commit is contained in:
Aiden McClelland
2025-10-31 15:41:00 -06:00
parent 5ae9a555ce
commit 5852bcadf8
6 changed files with 116 additions and 194 deletions

View File

@@ -1,26 +1,38 @@
#!/bin/bash #!/bin/bash
if [ -z "$iiface" ] || [ -z "$oiface" ] || [ -z "$sip" ] || [ -z "$dip" ] || [ -z "$sport" ] || [ -z "$dport" ]; then if [ -z "$sip" ] || [ -z "$dip" ] || [ -z "$sport" ] || [ -z "$dport" ]; then
>&2 echo 'missing required env var' >&2 echo 'missing required env var'
exit 1 exit 1
fi fi
kind="-A" # Helper function to check if a rule exists
nat_rule_exists() {
iptables -t nat -C "$@" 2>/dev/null
}
# Helper function to add or delete a rule idempotently
# Usage: apply_rule [add|del] <iptables args...>
apply_nat_rule() {
local action="$1"
shift
if [ "$action" = "add" ]; then
# Only add if rule doesn't exist
if ! rule_exists "$@"; then
iptables -t nat -A "$@"
fi
elif [ "$action" = "del" ]; then
if rule_exists "$@"; then
iptables -t nat -D "$@"
fi
fi
}
if [ "$UNDO" = 1 ]; then if [ "$UNDO" = 1 ]; then
kind="-D" action="del"
else
action="add"
fi fi
iptables -t nat "$kind" POSTROUTING -o $iiface -j MASQUERADE apply_nat_rule "$action" PREROUTING -p tcp -d $sip --dport $sport -j DNAT --to-destination $dip:$dport
iptables -t nat "$kind" PREROUTING -i $iiface -p tcp --dport $sport -j DNAT --to-destination $dip:$dport apply_nat_rule "$action" OUTPUT -p tcp -d $sip --dport $sport -j DNAT --to-destination $dip:$dport
iptables -t nat "$kind" PREROUTING -i $iiface -p udp --dport $sport -j DNAT --to-destination $dip:$dport
iptables -t nat "$kind" PREROUTING -i $oiface -s $dip/24 -d $sip -p tcp --dport $sport -j DNAT --to-destination $dip:$dport
iptables -t nat "$kind" PREROUTING -i $oiface -s $dip/24 -d $sip -p udp --dport $sport -j DNAT --to-destination $dip:$dport
iptables -t nat "$kind" POSTROUTING -o $oiface -s $dip/24 -d $dip/32 -p tcp --dport $dport -j SNAT --to-source $sip:$sport
iptables -t nat "$kind" POSTROUTING -o $oiface -s $dip/24 -d $dip/32 -p udp --dport $dport -j SNAT --to-source $sip:$sport
iptables -t nat "$kind" PREROUTING -i $iiface -s $sip/32 -d $sip -p tcp --dport $sport -j DNAT --to-destination $dip:$dport
iptables -t nat "$kind" PREROUTING -i $iiface -s $sip/32 -d $sip -p udp --dport $sport -j DNAT --to-destination $dip:$dport
iptables -t nat "$kind" POSTROUTING -o $oiface -s $sip/32 -d $dip/32 -p tcp --dport $dport -j SNAT --to-source $sip:$sport
iptables -t nat "$kind" POSTROUTING -o $oiface -s $sip/32 -d $dip/32 -p udp --dport $dport -j SNAT --to-source $sip:$sport

View File

@@ -1,6 +1,7 @@
use std::collections::{BTreeMap, BTreeSet}; use std::collections::{BTreeMap, BTreeSet};
use std::net::{IpAddr, SocketAddr, SocketAddrV6}; use std::net::{IpAddr, SocketAddrV4};
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
use std::time::Duration;
use futures::channel::oneshot; use futures::channel::oneshot;
use helpers::NonDetachingJoinHandle; use helpers::NonDetachingJoinHandle;
@@ -16,7 +17,6 @@ use tokio::sync::mpsc;
use crate::context::{CliContext, RpcContext}; use crate::context::{CliContext, RpcContext};
use crate::db::model::public::NetworkInterfaceInfo; use crate::db::model::public::NetworkInterfaceInfo;
use crate::net::gateway::{DynInterfaceFilter, InterfaceFilter}; use crate::net::gateway::{DynInterfaceFilter, InterfaceFilter};
use crate::net::utils::ipv6_is_link_local;
use crate::prelude::*; use crate::prelude::*;
use crate::util::Invoke; use crate::util::Invoke;
use crate::util::serde::{HandlerExtSerde, display_serializable}; use crate::util::serde::{HandlerExtSerde, display_serializable};
@@ -76,51 +76,46 @@ pub fn forward_api<C: Context>() -> ParentHandler<C> {
} }
struct ForwardMapping { struct ForwardMapping {
source: SocketAddr, source: SocketAddrV4,
target: SocketAddr, target: SocketAddrV4,
interface: GatewayId,
rc: Weak<()>, rc: Weak<()>,
} }
#[derive(Default)] #[derive(Default)]
struct PortForwardState { struct PortForwardState {
mappings: BTreeMap<SocketAddr, ForwardMapping>, // source -> target mappings: BTreeMap<SocketAddrV4, ForwardMapping>, // source -> target
} }
impl PortForwardState { impl PortForwardState {
async fn add_forward( async fn add_forward(
&mut self, &mut self,
interface: GatewayId, source: SocketAddrV4,
source: SocketAddr, target: SocketAddrV4,
target: SocketAddr,
rc: Arc<()>,
) -> Result<Arc<()>, Error> { ) -> Result<Arc<()>, Error> {
// Check if mapping already exists
if let Some(existing) = self.mappings.get_mut(&source) { if let Some(existing) = self.mappings.get_mut(&source) {
// If it's the same target and interface, just update the reference if existing.target == target {
if existing.target == target && existing.interface == interface {
if let Some(existing_rc) = existing.rc.upgrade() { if let Some(existing_rc) = existing.rc.upgrade() {
return Ok(existing_rc); return Ok(existing_rc);
} else { } else {
let rc = Arc::new(());
existing.rc = Arc::downgrade(&rc); existing.rc = Arc::downgrade(&rc);
return Ok(rc); return Ok(rc);
} }
} else { } else {
// Different target/interface, need to remove old and add new // Different target, need to remove old and add new
if let Some(mapping) = self.mappings.remove(&source) { if let Some(mapping) = self.mappings.remove(&source) {
unforward(mapping.interface.as_str(), mapping.source, mapping.target).await?; unforward(mapping.source, mapping.target).await?;
} }
} }
} }
// Add the new forward let rc = Arc::new(());
forward(interface.as_str(), source, target).await?; forward(source, target).await?;
self.mappings.insert( self.mappings.insert(
source, source,
ForwardMapping { ForwardMapping {
source, source,
target, target,
interface: interface.clone(),
rc: Arc::downgrade(&rc), rc: Arc::downgrade(&rc),
}, },
); );
@@ -129,7 +124,7 @@ impl PortForwardState {
} }
async fn gc(&mut self) -> Result<(), Error> { async fn gc(&mut self) -> Result<(), Error> {
let to_remove: Vec<SocketAddr> = self let to_remove: Vec<SocketAddrV4> = self
.mappings .mappings
.iter() .iter()
.filter(|(_, mapping)| mapping.rc.strong_count() == 0) .filter(|(_, mapping)| mapping.rc.strong_count() == 0)
@@ -138,13 +133,13 @@ impl PortForwardState {
for source in to_remove { for source in to_remove {
if let Some(mapping) = self.mappings.remove(&source) { if let Some(mapping) = self.mappings.remove(&source) {
unforward(mapping.interface.as_str(), mapping.source, mapping.target).await?; unforward(mapping.source, mapping.target).await?;
} }
} }
Ok(()) Ok(())
} }
fn dump(&self) -> BTreeMap<SocketAddr, SocketAddr> { fn dump(&self) -> BTreeMap<SocketAddrV4, SocketAddrV4> {
self.mappings self.mappings
.iter() .iter()
.filter(|(_, mapping)| mapping.rc.strong_count() > 0) .filter(|(_, mapping)| mapping.rc.strong_count() > 0)
@@ -159,9 +154,7 @@ impl Drop for PortForwardState {
let mappings = std::mem::take(&mut self.mappings); let mappings = std::mem::take(&mut self.mappings);
tokio::spawn(async move { tokio::spawn(async move {
for (_, mapping) in mappings { for (_, mapping) in mappings {
unforward(mapping.interface.as_str(), mapping.source, mapping.target) unforward(mapping.source, mapping.target).await.log_err();
.await
.log_err();
} }
}); });
} }
@@ -170,17 +163,15 @@ impl Drop for PortForwardState {
enum PortForwardCommand { enum PortForwardCommand {
AddForward { AddForward {
interface: GatewayId, source: SocketAddrV4,
source: SocketAddr, target: SocketAddrV4,
target: SocketAddr,
rc: Arc<()>,
respond: oneshot::Sender<Result<Arc<()>, Error>>, respond: oneshot::Sender<Result<Arc<()>, Error>>,
}, },
Gc { Gc {
respond: oneshot::Sender<Result<(), Error>>, respond: oneshot::Sender<Result<(), Error>>,
}, },
Dump { Dump {
respond: oneshot::Sender<BTreeMap<SocketAddr, SocketAddr>>, respond: oneshot::Sender<BTreeMap<SocketAddrV4, SocketAddrV4>>,
}, },
} }
@@ -193,17 +184,50 @@ impl PortForwardController {
pub fn new() -> Self { pub fn new() -> Self {
let (req_send, mut req_recv) = mpsc::unbounded_channel::<PortForwardCommand>(); let (req_send, mut req_recv) = mpsc::unbounded_channel::<PortForwardCommand>();
let thread = NonDetachingJoinHandle::from(tokio::spawn(async move { let thread = NonDetachingJoinHandle::from(tokio::spawn(async move {
while let Err(e) = async {
Command::new("sysctl")
.arg("-w")
.arg("net.ipv4.ip_forward=1")
.invoke(ErrorKind::Network)
.await?;
if Command::new("iptables")
.arg("-t")
.arg("nat")
.arg("-C")
.arg("POSTROUTING")
.arg("-j")
.arg("MASQUERADE")
.invoke(ErrorKind::Network)
.await
.is_err()
{
Command::new("iptables")
.arg("-t")
.arg("nat")
.arg("-A")
.arg("POSTROUTING")
.arg("-j")
.arg("MASQUERADE")
.invoke(ErrorKind::Network)
.await?;
}
Ok::<_, Error>(())
}
.await
{
tracing::error!("error initializing PortForwardController: {e:#}");
tracing::debug!("{e:?}");
tokio::time::sleep(Duration::from_secs(5)).await;
}
let mut state = PortForwardState::default(); let mut state = PortForwardState::default();
while let Some(cmd) = req_recv.recv().await { while let Some(cmd) = req_recv.recv().await {
match cmd { match cmd {
PortForwardCommand::AddForward { PortForwardCommand::AddForward {
interface,
source, source,
target, target,
rc,
respond, respond,
} => { } => {
let result = state.add_forward(interface, source, target, rc).await; let result = state.add_forward(source, target).await;
respond.send(result).ok(); respond.send(result).ok();
} }
PortForwardCommand::Gc { respond } => { PortForwardCommand::Gc { respond } => {
@@ -225,18 +249,14 @@ impl PortForwardController {
pub async fn add_forward( pub async fn add_forward(
&self, &self,
interface: GatewayId, source: SocketAddrV4,
source: SocketAddr, target: SocketAddrV4,
target: SocketAddr,
) -> Result<Arc<()>, Error> { ) -> Result<Arc<()>, Error> {
let rc = Arc::new(());
let (send, recv) = oneshot::channel(); let (send, recv) = oneshot::channel();
self.req self.req
.send(PortForwardCommand::AddForward { .send(PortForwardCommand::AddForward {
interface,
source, source,
target, target,
rc,
respond: send, respond: send,
}) })
.map_err(err_has_exited)?; .map_err(err_has_exited)?;
@@ -253,7 +273,7 @@ impl PortForwardController {
recv.await.map_err(err_has_exited)? recv.await.map_err(err_has_exited)?
} }
pub async fn dump(&self) -> Result<BTreeMap<SocketAddr, SocketAddr>, Error> { pub async fn dump(&self) -> Result<BTreeMap<SocketAddrV4, SocketAddrV4>, Error> {
let (send, recv) = oneshot::channel(); let (send, recv) = oneshot::channel();
self.req self.req
.send(PortForwardCommand::Dump { respond: send }) .send(PortForwardCommand::Dump { respond: send })
@@ -265,7 +285,7 @@ impl PortForwardController {
struct InterfaceForwardRequest { struct InterfaceForwardRequest {
external: u16, external: u16,
target: SocketAddr, target: SocketAddrV4,
filter: DynInterfaceFilter, filter: DynInterfaceFilter,
rc: Arc<()>, rc: Arc<()>,
} }
@@ -273,9 +293,9 @@ struct InterfaceForwardRequest {
#[derive(Clone)] #[derive(Clone)]
struct InterfaceForwardEntry { struct InterfaceForwardEntry {
external: u16, external: u16,
filter: BTreeMap<DynInterfaceFilter, (SocketAddr, Weak<()>)>, filter: BTreeMap<DynInterfaceFilter, (SocketAddrV4, Weak<()>)>,
// Maps source SocketAddr -> strong reference for the forward created in PortForwardController // Maps source SocketAddr -> strong reference for the forward created in PortForwardController
forwards: BTreeMap<SocketAddr, Arc<()>>, forwards: BTreeMap<SocketAddrV4, Arc<()>>,
} }
impl IdOrdItem for InterfaceForwardEntry { impl IdOrdItem for InterfaceForwardEntry {
@@ -301,7 +321,7 @@ impl InterfaceForwardEntry {
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>, ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
port_forward: &PortForwardController, port_forward: &PortForwardController,
) -> Result<(), Error> { ) -> Result<(), Error> {
let mut keep = BTreeSet::<SocketAddr>::new(); let mut keep = BTreeSet::<SocketAddrV4>::new();
for (iface, info) in ip_info.iter() { for (iface, info) in ip_info.iter() {
if let Some(target) = self if let Some(target) = self
@@ -312,26 +332,16 @@ impl InterfaceForwardEntry {
.map(|(_, (target, _))| *target) .map(|(_, (target, _))| *target)
{ {
if let Some(ip_info) = &info.ip_info { if let Some(ip_info) = &info.ip_info {
for ipnet in &ip_info.subnets { for addr in ip_info.subnets.iter().filter_map(|net| {
let addr = match ipnet.addr() { if let IpAddr::V4(ip) = net.addr() {
IpAddr::V6(ip6) => SocketAddrV6::new( Some(SocketAddrV4::new(ip, self.external))
ip6, } else {
self.external, None
0, }
if ipv6_is_link_local(ip6) { }) {
ip_info.scope_id
} else {
0
},
)
.into(),
ip => SocketAddr::new(ip, self.external),
};
keep.insert(addr); keep.insert(addr);
if !self.forwards.contains_key(&addr) { if !self.forwards.contains_key(&addr) {
let rc = port_forward let rc = port_forward.add_forward(addr, target).await?;
.add_forward(iface.clone(), addr, target)
.await?;
self.forwards.insert(addr, rc); self.forwards.insert(addr, rc);
} }
} }
@@ -387,10 +397,8 @@ impl InterfaceForwardEntry {
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>, ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
port_forward: &PortForwardController, port_forward: &PortForwardController,
) -> Result<(), Error> { ) -> Result<(), Error> {
// Clean up dead filter references
self.filter.retain(|_, (_, rc)| rc.strong_count() > 0); self.filter.retain(|_, (_, rc)| rc.strong_count() > 0);
// Update to add/remove forwards based on current state (this will drop strong references as needed)
self.update(ip_info, port_forward).await self.update(ip_info, port_forward).await
} }
} }
@@ -445,7 +453,7 @@ pub struct ForwardTable(pub BTreeMap<u16, ForwardTarget>);
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ForwardTarget { pub struct ForwardTarget {
pub target: SocketAddr, pub target: SocketAddrV4,
pub filter: String, pub filter: String,
} }
@@ -545,7 +553,7 @@ impl InterfacePortForwardController {
&self, &self,
external: u16, external: u16,
filter: DynInterfaceFilter, filter: DynInterfaceFilter,
target: SocketAddr, target: SocketAddrV4,
) -> Result<Arc<()>, Error> { ) -> Result<Arc<()>, Error> {
let rc = Arc::new(()); let rc = Arc::new(());
let (send, recv) = oneshot::channel(); let (send, recv) = oneshot::channel();
@@ -582,16 +590,8 @@ impl InterfacePortForwardController {
} }
} }
async fn forward(interface: &str, source: SocketAddr, target: SocketAddr) -> Result<(), Error> { async fn forward(source: SocketAddrV4, target: SocketAddrV4) -> Result<(), Error> {
if interface == START9_BRIDGE_IFACE {
return Ok(());
}
if source.is_ipv6() {
return Ok(()); // TODO: socat? ip6tables?
}
Command::new("/usr/lib/startos/scripts/forward-port") Command::new("/usr/lib/startos/scripts/forward-port")
.env("iiface", interface)
.env("oiface", START9_BRIDGE_IFACE)
.env("sip", source.ip().to_string()) .env("sip", source.ip().to_string())
.env("dip", target.ip().to_string()) .env("dip", target.ip().to_string())
.env("sport", source.port().to_string()) .env("sport", source.port().to_string())
@@ -601,17 +601,9 @@ async fn forward(interface: &str, source: SocketAddr, target: SocketAddr) -> Res
Ok(()) Ok(())
} }
async fn unforward(interface: &str, source: SocketAddr, target: SocketAddr) -> Result<(), Error> { async fn unforward(source: SocketAddrV4, target: SocketAddrV4) -> Result<(), Error> {
if interface == START9_BRIDGE_IFACE {
return Ok(());
}
if source.is_ipv6() {
return Ok(()); // TODO: socat? ip6tables?
}
Command::new("/usr/lib/startos/scripts/forward-port") Command::new("/usr/lib/startos/scripts/forward-port")
.env("UNDO", "1") .env("UNDO", "1")
.env("iiface", interface)
.env("oiface", START9_BRIDGE_IFACE)
.env("sip", source.ip().to_string()) .env("sip", source.ip().to_string())
.env("dip", target.ip().to_string()) .env("dip", target.ip().to_string())
.env("sport", source.port().to_string()) .env("sport", source.port().to_string())

View File

@@ -1,5 +1,5 @@
use std::collections::{BTreeMap, BTreeSet}; use std::collections::{BTreeMap, BTreeSet};
use std::net::{Ipv4Addr, SocketAddr}; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
use color_eyre::eyre::eyre; use color_eyre::eyre::eyre;
@@ -149,7 +149,7 @@ impl NetController {
#[derive(Default, Debug)] #[derive(Default, Debug)]
struct HostBinds { struct HostBinds {
forwards: BTreeMap<u16, (SocketAddr, DynInterfaceFilter, Arc<()>)>, forwards: BTreeMap<u16, (SocketAddrV4, DynInterfaceFilter, Arc<()>)>,
vhosts: BTreeMap<(Option<InternedString>, u16), (ProxyTarget, Arc<()>)>, vhosts: BTreeMap<(Option<InternedString>, u16), (ProxyTarget, Arc<()>)>,
private_dns: BTreeMap<InternedString, Arc<()>>, private_dns: BTreeMap<InternedString, Arc<()>>,
tor: BTreeMap<OnionAddress, (OrdMap<u16, SocketAddr>, Vec<Arc<()>>)>, tor: BTreeMap<OnionAddress, (OrdMap<u16, SocketAddr>, Vec<Arc<()>>)>,
@@ -241,7 +241,7 @@ impl NetServiceData {
} }
async fn update(&mut self, ctrl: &NetController, id: HostId, host: Host) -> Result<(), Error> { async fn update(&mut self, ctrl: &NetController, id: HostId, host: Host) -> Result<(), Error> {
let mut forwards: BTreeMap<u16, (SocketAddr, DynInterfaceFilter)> = BTreeMap::new(); let mut forwards: BTreeMap<u16, (SocketAddrV4, DynInterfaceFilter)> = BTreeMap::new();
let mut vhosts: BTreeMap<(Option<InternedString>, u16), ProxyTarget> = BTreeMap::new(); let mut vhosts: BTreeMap<(Option<InternedString>, u16), ProxyTarget> = BTreeMap::new();
let mut private_dns: BTreeSet<InternedString> = BTreeSet::new(); let mut private_dns: BTreeSet<InternedString> = BTreeSet::new();
let mut tor: BTreeMap<OnionAddress, (TorSecretKey, OrdMap<u16, SocketAddr>)> = let mut tor: BTreeMap<OnionAddress, (TorSecretKey, OrdMap<u16, SocketAddr>)> =
@@ -442,7 +442,7 @@ impl NetServiceData {
forwards.insert( forwards.insert(
external, external,
( (
(self.ip, *port).into(), SocketAddrV4::new(self.ip, *port),
AndFilter( AndFilter(
SecureFilter { SecureFilter {
secure: bind.options.secure.is_some(), secure: bind.options.secure.is_some(),

View File

@@ -9,7 +9,6 @@ use serde::{Deserialize, Serialize};
use crate::context::CliContext; use crate::context::CliContext;
use crate::prelude::*; use crate::prelude::*;
use crate::tunnel::context::TunnelContext; use crate::tunnel::context::TunnelContext;
use crate::tunnel::db::GatewayPort;
use crate::tunnel::wg::{WgConfig, WgSubnetClients, WgSubnetConfig}; use crate::tunnel::wg::{WgConfig, WgSubnetClients, WgSubnetConfig};
use crate::util::serde::{HandlerExtSerde, display_serializable}; use crate::util::serde::{HandlerExtSerde, display_serializable};
@@ -359,7 +358,7 @@ pub async fn show_config(
#[derive(Deserialize, Serialize, Parser)] #[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct AddPortForwardParams { pub struct AddPortForwardParams {
source: GatewayPort, source: SocketAddrV4,
target: SocketAddrV4, target: SocketAddrV4,
} }
@@ -372,26 +371,7 @@ pub async fn add_forward(
.await .await
.result?; .result?;
// source is (GatewayId, port), target is SocketAddrV4 let rc = ctx.forward.add_forward(source, target).await?;
// Find the first IP address for the specified gateway and create a source SocketAddr
let source_addr = ctx.net_iface.peek(|ifaces| {
ifaces
.get(&source.0)
.and_then(|info| info.ip_info.as_ref())
.and_then(|ip_info| ip_info.subnets.iter().next())
.map(|ipnet| std::net::SocketAddr::new(ipnet.addr(), source.1))
})
.ok_or_else(|| {
Error::new(
eyre!("Gateway {} not found or has no IP addresses", source.0),
crate::ErrorKind::Network,
)
})?;
let rc = ctx
.forward
.add_forward(source.0.clone(), source_addr, target.into())
.await?;
ctx.active_forwards.mutate(|m| { ctx.active_forwards.mutate(|m| {
m.insert(source, rc); m.insert(source, rc);
}); });
@@ -401,7 +381,7 @@ pub async fn add_forward(
#[derive(Deserialize, Serialize, Parser)] #[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct RemovePortForwardParams { pub struct RemovePortForwardParams {
source: GatewayPort, source: SocketAddrV4,
} }
pub async fn remove_forward( pub async fn remove_forward(

View File

@@ -1,5 +1,5 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr, SocketAddrV4};
use std::ops::Deref; use std::ops::Deref;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
@@ -33,7 +33,7 @@ use crate::prelude::*;
use crate::rpc_continuations::{OpenAuthedContinuations, RpcContinuations}; use crate::rpc_continuations::{OpenAuthedContinuations, RpcContinuations};
use crate::tunnel::TUNNEL_DEFAULT_LISTEN; use crate::tunnel::TUNNEL_DEFAULT_LISTEN;
use crate::tunnel::api::tunnel_api; use crate::tunnel::api::tunnel_api;
use crate::tunnel::db::{GatewayPort, TunnelDatabase}; use crate::tunnel::db::TunnelDatabase;
use crate::tunnel::wg::WIREGUARD_INTERFACE_NAME; use crate::tunnel::wg::WIREGUARD_INTERFACE_NAME;
use crate::util::Invoke; use crate::util::Invoke;
use crate::util::collections::OrdMapIterMut; use crate::util::collections::OrdMapIterMut;
@@ -85,7 +85,7 @@ pub struct TunnelContextSeed {
pub ephemeral_sessions: SyncMutex<Sessions>, pub ephemeral_sessions: SyncMutex<Sessions>,
pub net_iface: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>, pub net_iface: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
pub forward: PortForwardController, pub forward: PortForwardController,
pub active_forwards: SyncMutex<BTreeMap<GatewayPort, Arc<()>>>, pub active_forwards: SyncMutex<BTreeMap<SocketAddrV4, Arc<()>>>,
pub shutdown: Sender<()>, pub shutdown: Sender<()>,
} }
@@ -182,24 +182,7 @@ impl TunnelContext {
peek.as_wg().de()?.sync().await?; peek.as_wg().de()?.sync().await?;
let mut active_forwards = BTreeMap::new(); let mut active_forwards = BTreeMap::new();
for (from, to) in peek.as_port_forwards().de()?.0 { for (from, to) in peek.as_port_forwards().de()?.0 {
// from is (GatewayId, u16), to is SocketAddr active_forwards.insert(from, forward.add_forward(from, to).await?);
// Create the source SocketAddr for each interface matching the gateway
for (gateway_id, info) in net_iface.peek(|i| i.clone()) {
if gateway_id == from.0 {
if let Some(ip_info) = &info.ip_info {
if let Some(ipnet) = ip_info.subnets.iter().next() {
let source = std::net::SocketAddr::new(ipnet.addr(), from.1);
active_forwards.insert(
from.clone(),
forward
.add_forward(gateway_id.clone(), source, to.into())
.await?,
);
}
}
break;
}
}
} }
Ok(Self(Arc::new(TunnelContextSeed { Ok(Self(Arc::new(TunnelContextSeed {

View File

@@ -31,51 +31,6 @@ use crate::tunnel::wg::WgServer;
use crate::util::net::WebSocketExt; use crate::util::net::WebSocketExt;
use crate::util::serde::{HandlerExtSerde, apply_expr, deserialize_from_str, serialize_display}; use crate::util::serde::{HandlerExtSerde, apply_expr, deserialize_from_str, serialize_display};
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct GatewayPort(pub GatewayId, pub u16);
impl std::fmt::Display for GatewayPort {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}:{}", self.0, self.1)
}
}
impl std::str::FromStr for GatewayPort {
type Err = crate::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut parts = s.splitn(2, ':');
let gw: GatewayId = parts
.next()
.ok_or_else(|| Error::new(eyre!("missing gateway id"), ErrorKind::ParseNetAddress))?
.parse()?;
let port: u16 = parts
.next()
.ok_or_else(|| Error::new(eyre!("missing port"), ErrorKind::ParseNetAddress))?
.parse()?;
Ok(GatewayPort(gw, port))
}
}
impl Serialize for GatewayPort {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serialize_display(self, serializer)
}
}
impl<'de> Deserialize<'de> for GatewayPort {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserialize_from_str(deserializer)
}
}
impl ValueParserFactory for GatewayPort {
type Parser = FromStrParser<Self>;
fn value_parser() -> Self::Parser {
FromStrParser::new()
}
}
#[derive(Default, Deserialize, Serialize, HasModel)] #[derive(Default, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
#[model = "Model<Self>"] #[model = "Model<Self>"]
@@ -90,9 +45,9 @@ pub struct TunnelDatabase {
} }
#[derive(Clone, Debug, Default, Deserialize, Serialize)] #[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct PortForwards(pub BTreeMap<GatewayPort, SocketAddrV4>); pub struct PortForwards(pub BTreeMap<SocketAddrV4, SocketAddrV4>);
impl Map for PortForwards { impl Map for PortForwards {
type Key = GatewayPort; type Key = SocketAddrV4;
type Value = SocketAddrV4; type Value = SocketAddrV4;
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> { fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
Self::key_string(key) Self::key_string(key)