wip: separate port forward controller into parts

This commit is contained in:
Aiden McClelland
2025-10-31 12:21:02 -06:00
parent afc69b13a0
commit 5ae9a555ce
4 changed files with 315 additions and 86 deletions

View File

@@ -75,7 +75,195 @@ pub fn forward_api<C: Context>() -> ParentHandler<C> {
)
}
struct ForwardRequest {
struct ForwardMapping {
source: SocketAddr,
target: SocketAddr,
interface: GatewayId,
rc: Weak<()>,
}
#[derive(Default)]
struct PortForwardState {
mappings: BTreeMap<SocketAddr, ForwardMapping>, // source -> target
}
impl PortForwardState {
async fn add_forward(
&mut self,
interface: GatewayId,
source: SocketAddr,
target: SocketAddr,
rc: Arc<()>,
) -> Result<Arc<()>, Error> {
// Check if mapping already exists
if let Some(existing) = self.mappings.get_mut(&source) {
// If it's the same target and interface, just update the reference
if existing.target == target && existing.interface == interface {
if let Some(existing_rc) = existing.rc.upgrade() {
return Ok(existing_rc);
} else {
existing.rc = Arc::downgrade(&rc);
return Ok(rc);
}
} else {
// Different target/interface, need to remove old and add new
if let Some(mapping) = self.mappings.remove(&source) {
unforward(mapping.interface.as_str(), mapping.source, mapping.target).await?;
}
}
}
// Add the new forward
forward(interface.as_str(), source, target).await?;
self.mappings.insert(
source,
ForwardMapping {
source,
target,
interface: interface.clone(),
rc: Arc::downgrade(&rc),
},
);
Ok(rc)
}
async fn gc(&mut self) -> Result<(), Error> {
let to_remove: Vec<SocketAddr> = self
.mappings
.iter()
.filter(|(_, mapping)| mapping.rc.strong_count() == 0)
.map(|(source, _)| *source)
.collect();
for source in to_remove {
if let Some(mapping) = self.mappings.remove(&source) {
unforward(mapping.interface.as_str(), mapping.source, mapping.target).await?;
}
}
Ok(())
}
fn dump(&self) -> BTreeMap<SocketAddr, SocketAddr> {
self.mappings
.iter()
.filter(|(_, mapping)| mapping.rc.strong_count() > 0)
.map(|(source, mapping)| (*source, mapping.target))
.collect()
}
}
impl Drop for PortForwardState {
fn drop(&mut self) {
if !self.mappings.is_empty() {
let mappings = std::mem::take(&mut self.mappings);
tokio::spawn(async move {
for (_, mapping) in mappings {
unforward(mapping.interface.as_str(), mapping.source, mapping.target)
.await
.log_err();
}
});
}
}
}
enum PortForwardCommand {
AddForward {
interface: GatewayId,
source: SocketAddr,
target: SocketAddr,
rc: Arc<()>,
respond: oneshot::Sender<Result<Arc<()>, Error>>,
},
Gc {
respond: oneshot::Sender<Result<(), Error>>,
},
Dump {
respond: oneshot::Sender<BTreeMap<SocketAddr, SocketAddr>>,
},
}
pub struct PortForwardController {
req: mpsc::UnboundedSender<PortForwardCommand>,
_thread: NonDetachingJoinHandle<()>,
}
impl PortForwardController {
pub fn new() -> Self {
let (req_send, mut req_recv) = mpsc::unbounded_channel::<PortForwardCommand>();
let thread = NonDetachingJoinHandle::from(tokio::spawn(async move {
let mut state = PortForwardState::default();
while let Some(cmd) = req_recv.recv().await {
match cmd {
PortForwardCommand::AddForward {
interface,
source,
target,
rc,
respond,
} => {
let result = state.add_forward(interface, source, target, rc).await;
respond.send(result).ok();
}
PortForwardCommand::Gc { respond } => {
let result = state.gc().await;
respond.send(result).ok();
}
PortForwardCommand::Dump { respond } => {
respond.send(state.dump()).ok();
}
}
}
}));
Self {
req: req_send,
_thread: thread,
}
}
pub async fn add_forward(
&self,
interface: GatewayId,
source: SocketAddr,
target: SocketAddr,
) -> Result<Arc<()>, Error> {
let rc = Arc::new(());
let (send, recv) = oneshot::channel();
self.req
.send(PortForwardCommand::AddForward {
interface,
source,
target,
rc,
respond: send,
})
.map_err(err_has_exited)?;
recv.await.map_err(err_has_exited)?
}
pub async fn gc(&self) -> Result<(), Error> {
let (send, recv) = oneshot::channel();
self.req
.send(PortForwardCommand::Gc { respond: send })
.map_err(err_has_exited)?;
recv.await.map_err(err_has_exited)?
}
pub async fn dump(&self) -> Result<BTreeMap<SocketAddr, SocketAddr>, Error> {
let (send, recv) = oneshot::channel();
self.req
.send(PortForwardCommand::Dump { respond: send })
.map_err(err_has_exited)?;
recv.await.map_err(err_has_exited)
}
}
struct InterfaceForwardRequest {
external: u16,
target: SocketAddr,
filter: DynInterfaceFilter,
@@ -83,12 +271,14 @@ struct ForwardRequest {
}
#[derive(Clone)]
struct ForwardEntry {
struct InterfaceForwardEntry {
external: u16,
filter: BTreeMap<DynInterfaceFilter, (SocketAddr, Weak<()>)>,
forwards: BTreeMap<SocketAddr, (GatewayId, SocketAddr)>,
// Maps source SocketAddr -> strong reference for the forward created in PortForwardController
forwards: BTreeMap<SocketAddr, Arc<()>>,
}
impl IdOrdItem for ForwardEntry {
impl IdOrdItem for InterfaceForwardEntry {
type Key<'a> = u16;
fn key(&self) -> Self::Key<'_> {
self.external
@@ -96,7 +286,8 @@ impl IdOrdItem for ForwardEntry {
iddqd::id_upcast!();
}
impl ForwardEntry {
impl InterfaceForwardEntry {
fn new(external: u16) -> Self {
Self {
external,
@@ -105,26 +296,13 @@ impl ForwardEntry {
}
}
fn take(&mut self) -> Self {
Self {
external: self.external,
filter: std::mem::take(&mut self.filter),
forwards: std::mem::take(&mut self.forwards),
}
}
async fn destroy(mut self) -> Result<(), Error> {
while let Some((source, (interface, target))) = self.forwards.pop_first() {
unforward(interface.as_str(), source, target).await?;
}
Ok(())
}
async fn update(
&mut self,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
port_forward: &PortForwardController,
) -> Result<(), Error> {
let mut keep = BTreeSet::<SocketAddr>::new();
for (iface, info) in ip_info.iter() {
if let Some(target) = self
.filter
@@ -151,40 +329,36 @@ impl ForwardEntry {
};
keep.insert(addr);
if !self.forwards.contains_key(&addr) {
forward(iface.as_str(), addr, target).await?;
self.forwards.insert(addr, (iface.clone(), target));
let rc = port_forward
.add_forward(iface.clone(), addr, target)
.await?;
self.forwards.insert(addr, rc);
}
}
}
}
}
let rm = self
.forwards
.keys()
.copied()
.filter(|a| !keep.contains(a))
.collect::<Vec<_>>();
for rm in rm {
if let Some((source, (interface, target))) = self.forwards.remove_entry(&rm) {
unforward(interface.as_str(), source, target).await?;
}
}
// Remove forwards that should no longer exist (drops the strong references)
self.forwards.retain(|addr, _| keep.contains(addr));
Ok(())
}
async fn update_request(
&mut self,
ForwardRequest {
InterfaceForwardRequest {
external,
target,
filter,
mut rc,
}: ForwardRequest,
}: InterfaceForwardRequest,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
port_forward: &PortForwardController,
) -> Result<Arc<()>, Error> {
if external != self.external {
return Err(Error::new(
eyre!("Mismatched external port in ForwardEntry"),
eyre!("Mismatched external port in InterfaceForwardEntry"),
ErrorKind::InvalidRequest,
));
}
@@ -203,44 +377,57 @@ impl ForwardEntry {
entry.1 = Arc::downgrade(&rc);
}
self.update(ip_info).await?;
self.update(ip_info, port_forward).await?;
Ok(rc)
}
async fn gc(
&mut self,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
port_forward: &PortForwardController,
) -> Result<(), Error> {
// Clean up dead filter references
self.filter.retain(|_, (_, rc)| rc.strong_count() > 0);
// Update to add/remove forwards based on current state (this will drop strong references as needed)
self.update(ip_info, port_forward).await
}
}
impl Drop for ForwardEntry {
fn drop(&mut self) {
if !self.forwards.is_empty() {
let take = self.take();
tokio::spawn(async move {
take.destroy().await.log_err();
});
struct InterfaceForwardState {
port_forward: PortForwardController,
state: IdOrdMap<InterfaceForwardEntry>,
}
impl InterfaceForwardState {
fn new(port_forward: PortForwardController) -> Self {
Self {
port_forward,
state: IdOrdMap::new(),
}
}
}
#[derive(Default, Clone)]
struct ForwardState {
state: IdOrdMap<ForwardEntry>,
}
impl ForwardState {
impl InterfaceForwardState {
async fn handle_request(
&mut self,
request: ForwardRequest,
request: InterfaceForwardRequest,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
) -> Result<Arc<()>, Error> {
self.state
.entry(request.external)
.or_insert_with(|| ForwardEntry::new(request.external))
.update_request(request, ip_info)
.or_insert_with(|| InterfaceForwardEntry::new(request.external))
.update_request(request, ip_info, &self.port_forward)
.await
}
async fn sync(
&mut self,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
) -> Result<(), Error> {
for mut entry in self.state.iter_mut() {
entry.update(ip_info).await?;
entry.gc(ip_info, &self.port_forward).await?;
}
Ok(())
}
@@ -255,13 +442,15 @@ fn err_has_exited<T>(_: T) -> Error {
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ForwardTable(pub BTreeMap<u16, ForwardTarget>);
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ForwardTarget {
pub target: SocketAddr,
pub filter: String,
}
impl From<&ForwardState> for ForwardTable {
fn from(value: &ForwardState) -> Self {
impl From<&InterfaceForwardState> for ForwardTable {
fn from(value: &InterfaceForwardState) -> Self {
Self(
value
.state
@@ -286,8 +475,11 @@ impl From<&ForwardState> for ForwardTable {
}
}
enum ForwardCommand {
Forward(ForwardRequest, oneshot::Sender<Result<Arc<()>, Error>>),
enum InterfaceForwardCommand {
Forward(
InterfaceForwardRequest,
oneshot::Sender<Result<Arc<()>, Error>>,
),
Sync(oneshot::Sender<Result<(), Error>>),
DumpTable(oneshot::Sender<ForwardTable>),
}
@@ -302,24 +494,33 @@ fn test() {
);
}
pub struct PortForwardController {
req: mpsc::UnboundedSender<ForwardCommand>,
pub struct InterfacePortForwardController {
req: mpsc::UnboundedSender<InterfaceForwardCommand>,
_thread: NonDetachingJoinHandle<()>,
}
impl PortForwardController {
impl InterfacePortForwardController {
pub fn new(mut ip_info: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>) -> Self {
let (req_send, mut req_recv) = mpsc::unbounded_channel::<ForwardCommand>();
let port_forward = PortForwardController::new();
let (req_send, mut req_recv) = mpsc::unbounded_channel::<InterfaceForwardCommand>();
let thread = NonDetachingJoinHandle::from(tokio::spawn(async move {
let mut state = ForwardState::default();
let mut state = InterfaceForwardState::new(port_forward);
let mut interfaces = ip_info.read_and_mark_seen();
loop {
tokio::select! {
msg = req_recv.recv() => {
if let Some(cmd) = msg {
match cmd {
ForwardCommand::Forward(req, re) => re.send(state.handle_request(req, &interfaces).await).ok(),
ForwardCommand::Sync(re) => re.send(state.sync(&interfaces).await).ok(),
ForwardCommand::DumpTable(re) => re.send((&state).into()).ok(),
InterfaceForwardCommand::Forward(req, re) => {
re.send(state.handle_request(req, &interfaces).await).ok()
}
InterfaceForwardCommand::Sync(re) => {
re.send(state.sync(&interfaces).await).ok()
}
InterfaceForwardCommand::DumpTable(re) => {
re.send((&state).into()).ok()
}
};
} else {
break;
@@ -328,15 +529,18 @@ impl PortForwardController {
_ = ip_info.changed() => {
interfaces = ip_info.read();
state.sync(&interfaces).await.log_err();
state.port_forward.gc().await.log_err();
}
}
}
}));
Self {
req: req_send,
_thread: thread,
}
}
pub async fn add(
&self,
external: u16,
@@ -346,8 +550,8 @@ impl PortForwardController {
let rc = Arc::new(());
let (send, recv) = oneshot::channel();
self.req
.send(ForwardCommand::Forward(
ForwardRequest {
.send(InterfaceForwardCommand::Forward(
InterfaceForwardRequest {
external,
target,
filter,
@@ -359,18 +563,20 @@ impl PortForwardController {
recv.await.map_err(err_has_exited)?
}
pub async fn gc(&self) -> Result<(), Error> {
let (send, recv) = oneshot::channel();
self.req
.send(ForwardCommand::Sync(send))
.send(InterfaceForwardCommand::Sync(send))
.map_err(err_has_exited)?;
recv.await.map_err(err_has_exited)?
}
pub async fn dump_table(&self) -> Result<ForwardTable, Error> {
let (req, res) = oneshot::channel();
self.req
.send(ForwardCommand::DumpTable(req))
.send(InterfaceForwardCommand::DumpTable(req))
.map_err(err_has_exited)?;
res.await.map_err(err_has_exited)
}

View File

@@ -18,7 +18,7 @@ use crate::db::model::public::NetworkInterfaceType;
use crate::error::ErrorCollection;
use crate::hostname::Hostname;
use crate::net::dns::DnsController;
use crate::net::forward::{PortForwardController, START9_BRIDGE_IFACE};
use crate::net::forward::{InterfacePortForwardController, START9_BRIDGE_IFACE};
use crate::net::gateway::{
AndFilter, DynInterfaceFilter, IdFilter, InterfaceFilter, NetworkInterfaceController, OrFilter,
PublicFilter, SecureFilter, TypeFilter,
@@ -42,7 +42,7 @@ pub struct NetController {
pub(super) tls_client_config: Arc<TlsClientConfig>,
pub(crate) net_iface: Arc<NetworkInterfaceController>,
pub(super) dns: DnsController,
pub(super) forward: PortForwardController,
pub(super) forward: InterfacePortForwardController,
pub(super) socks: SocksController,
pub(super) server_hostnames: Vec<Option<InternedString>>,
pub(crate) callbacks: Arc<ServiceCallbacks>,
@@ -76,7 +76,7 @@ impl NetController {
vhost: VHostController::new(db.clone(), net_iface.clone(), crypto_provider),
tls_client_config,
dns: DnsController::init(db, &net_iface.watcher).await?,
forward: PortForwardController::new(net_iface.watcher.subscribe()),
forward: InterfacePortForwardController::new(net_iface.watcher.subscribe()),
net_iface,
socks,
server_hostnames: vec![

View File

@@ -7,7 +7,6 @@ use rpc_toolkit::{Context, Empty, HandlerArgs, HandlerExt, ParentHandler, from_f
use serde::{Deserialize, Serialize};
use crate::context::CliContext;
use crate::net::gateway::{IdFilter, InterfaceFilter};
use crate::prelude::*;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::db::GatewayPort;
@@ -372,13 +371,26 @@ pub async fn add_forward(
.mutate(|db| db.as_port_forwards_mut().insert(&source, &target))
.await
.result?;
// source is (GatewayId, port), target is SocketAddrV4
// Find the first IP address for the specified gateway and create a source SocketAddr
let source_addr = ctx.net_iface.peek(|ifaces| {
ifaces
.get(&source.0)
.and_then(|info| info.ip_info.as_ref())
.and_then(|ip_info| ip_info.subnets.iter().next())
.map(|ipnet| std::net::SocketAddr::new(ipnet.addr(), source.1))
})
.ok_or_else(|| {
Error::new(
eyre!("Gateway {} not found or has no IP addresses", source.0),
crate::ErrorKind::Network,
)
})?;
let rc = ctx
.forward
.add(
source.1,
IdFilter(source.0.clone()).into_dyn(),
target.into(),
)
.add_forward(source.0.clone(), source_addr, target.into())
.await?;
ctx.active_forwards.mutate(|m| {
m.insert(source, rc);

View File

@@ -28,7 +28,6 @@ use crate::else_empty_dir;
use crate::middleware::auth::{Auth, AuthContext};
use crate::middleware::cors::Cors;
use crate::net::forward::PortForwardController;
use crate::net::gateway::{IdFilter, InterfaceFilter};
use crate::net::static_server::UiContext;
use crate::prelude::*;
use crate::rpc_continuations::{OpenAuthedContinuations, RpcContinuations};
@@ -130,7 +129,7 @@ impl TunnelContext {
.await
.result?;
let net_iface = Watch::new(net_iface);
let forward = PortForwardController::new(net_iface.clone_unseen());
let forward = PortForwardController::new();
Command::new("sysctl")
.arg("-w")
@@ -183,12 +182,24 @@ impl TunnelContext {
peek.as_wg().de()?.sync().await?;
let mut active_forwards = BTreeMap::new();
for (from, to) in peek.as_port_forwards().de()?.0 {
active_forwards.insert(
from.clone(),
forward
.add(from.1, IdFilter(from.0).into_dyn(), to.into())
.await?,
);
// from is (GatewayId, u16), to is SocketAddr
// Create the source SocketAddr for each interface matching the gateway
for (gateway_id, info) in net_iface.peek(|i| i.clone()) {
if gateway_id == from.0 {
if let Some(ip_info) = &info.ip_info {
if let Some(ipnet) = ip_info.subnets.iter().next() {
let source = std::net::SocketAddr::new(ipnet.addr(), from.1);
active_forwards.insert(
from.clone(),
forward
.add_forward(gateway_id.clone(), source, to.into())
.await?,
);
}
}
break;
}
}
}
Ok(Self(Arc::new(TunnelContextSeed {