Files
start-os/core/src/net/forward.rs
2026-02-23 13:35:34 -07:00

758 lines
23 KiB
Rust

use std::collections::{BTreeMap, BTreeSet};
use std::net::{IpAddr, SocketAddrV4};
use std::sync::{Arc, Weak};
use std::time::Duration;
use futures::channel::oneshot;
use iddqd::{IdOrdItem, IdOrdMap};
use imbl::OrdMap;
use ipnet::{IpNet, Ipv4Net};
use rand::Rng;
use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use tokio::sync::mpsc;
use crate::context::{CliContext, RpcContext};
use crate::db::model::public::NetworkInterfaceInfo;
use crate::prelude::*;
use crate::util::Invoke;
use crate::util::future::NonDetachingJoinHandle;
use crate::util::serde::{HandlerExtSerde, display_serializable};
use crate::util::sync::Watch;
use crate::{GatewayId, HOST_IP};
/// Name of a bridge network interface.
// NOTE(review): presumably the LXC container bridge, judging by the name — confirm.
pub const START9_BRIDGE_IFACE: &str = "lxcbr0";
/// Lowest port number handed out by `AvailablePorts::alloc`.
const EPHEMERAL_PORT_START: u16 = 49152;
// vhost.rs:89 — not allowed: <=1024, >=32768, 5355, 5432, 9050, 6010, 9051, 5353
const RESTRICTED_PORTS: &[u16] = &[5353, 5355, 5432, 6010, 9050, 9051];

/// Returns `true` when `port` must not be handed out by `AvailablePorts::try_alloc`.
///
/// A port is restricted when it is `<= 1024` or appears in `RESTRICTED_PORTS`.
fn is_restricted(port: u16) -> bool {
    match port {
        0..=1024 => true,
        _ => RESTRICTED_PORTS.iter().any(|&reserved| reserved == port),
    }
}
/// Conditions deciding on which local addresses an external port is forwarded.
///
/// `InterfaceForwardEntry::update` consults these for every IPv4 address of every gateway.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct ForwardRequirements {
/// Gateways on which the port is forwarded with no source-address filter.
pub public_gateways: BTreeSet<GatewayId>,
/// Local addresses on which the port is forwarded only for sources inside that address's subnet.
pub private_ips: BTreeSet<IpAddr>,
/// When `false`, interfaces whose `secure()` check returns `false` are skipped.
pub secure: bool,
}
impl std::fmt::Display for ForwardRequirements {
    /// Writes the requirements on one line, using `Debug` formatting for the two sets.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            public_gateways,
            private_ips,
            secure,
        } = self;
        f.write_fmt(format_args!(
            "ForwardRequirements {{ public: {:?}, private: {:?}, secure: {} }}",
            public_gateways, private_ips, secure
        ))
    }
}
/// Registry of allocated ports: maps each allocated port number to its SSL flag.
#[derive(Debug, Deserialize, Serialize)]
pub struct AvailablePorts(BTreeMap<u16, bool>);
impl AvailablePorts {
    /// Creates an empty registry with no ports allocated.
    pub fn new() -> Self {
        Self(BTreeMap::new())
    }

    /// Allocates a free port from the dynamic range `EPHEMERAL_PORT_START..=u16::MAX`
    /// and records its `ssl` flag.
    ///
    /// Up to 1000 random candidates are tried first. If every candidate is taken, the
    /// range is scanned in ascending order, so the call only fails when the range is
    /// genuinely exhausted.
    ///
    /// # Errors
    /// Returns an `ErrorKind::Network` error when every port in the range is allocated.
    pub fn alloc(&mut self, ssl: bool) -> Result<u16, Error> {
        let mut rng = rand::rng();
        for _ in 0..1000 {
            // Inclusive upper bound: 65535 is a valid port and part of the dynamic range.
            let port = rng.random_range(EPHEMERAL_PORT_START..=u16::MAX);
            if !self.0.contains_key(&port) {
                self.0.insert(port, ssl);
                return Ok(port);
            }
        }
        // Random probing failed (registry nearly full): fall back to a deterministic scan.
        let fallback = (EPHEMERAL_PORT_START..=u16::MAX).find(|port| !self.0.contains_key(port));
        if let Some(port) = fallback {
            self.0.insert(port, ssl);
            return Ok(port);
        }
        Err(Error::new(
            eyre!("{}", t!("net.forward.no-dynamic-ports-available")),
            ErrorKind::Network,
        ))
    }

    /// Try to allocate a specific port. Returns Some(port) if available, None if taken/restricted.
    pub fn try_alloc(&mut self, port: u16, ssl: bool) -> Option<u16> {
        if is_restricted(port) || self.0.contains_key(&port) {
            return None;
        }
        self.0.insert(port, ssl);
        Some(port)
    }

    /// Sets the SSL flag for `port`, inserting the port if it was not already recorded.
    pub fn set_ssl(&mut self, port: u16, ssl: bool) {
        self.0.insert(port, ssl);
    }

    /// Returns whether a given allocated port is SSL (`false` for unknown ports).
    pub fn is_ssl(&self, port: u16) -> bool {
        self.0.get(&port).copied().unwrap_or(false)
    }

    /// Releases every port in `ports`; ports that are not allocated are ignored.
    pub fn free(&mut self, ports: impl IntoIterator<Item = u16>) {
        for port in ports {
            self.0.remove(&port);
        }
    }
}
/// Builds the RPC handler tree for port forwarding.
///
/// Exposes one subcommand, `dump-table`, which fetches the forward table from the
/// net controller. Without an explicit output format it prints a FROM / TO / REQS table.
pub fn forward_api<C: Context>() -> ParentHandler<C> {
    let dump_table = from_fn_async(|ctx: RpcContext| async move {
        ctx.net_controller.forward.dump_table().await
    })
    .with_display_serializable()
    .with_custom_display_fn(|HandlerArgs { params, .. }, res| {
        use prettytable::*;
        match params.format {
            // An explicit format was requested: defer to the generic serializer.
            Some(format) => display_serializable(format, res),
            None => {
                let mut table = Table::new();
                table.add_row(row![bc => "FROM", "TO", "REQS"]);
                for (external, target) in res.0 {
                    table.add_row(row![external, target.target, target.reqs]);
                }
                table.print_tty(false)?;
                Ok(())
            }
        }
    })
    .with_call_remote::<CliContext>();

    ParentHandler::new().subcommand("dump-table", dump_table)
}
/// One installed port forward tracked by `PortForwardState`.
struct ForwardMapping {
/// Address and port the forward listens on; also the key in `PortForwardState::mappings`.
source: SocketAddrV4,
/// Address and port traffic is forwarded to.
target: SocketAddrV4,
/// Passed to the forward-port script as `dprefix`.
target_prefix: u8,
/// Optional source subnet, passed to the forward-port script as `src_subnet`.
src_filter: Option<IpNet>,
/// Weak handle to the caller-held token; `gc` removes the forward once no strong reference remains.
rc: Weak<()>,
}
/// All forwards currently installed by the `PortForwardController` task.
#[derive(Default)]
struct PortForwardState {
mappings: BTreeMap<SocketAddrV4, ForwardMapping>, // source -> target
}
impl PortForwardState {
    /// Ensures a forward from `source` to `target` exists and returns the token keeping it alive.
    ///
    /// * If an identical mapping (same target, target prefix and source filter) is already
    ///   installed, its token is reused, or a fresh token is attached when all previous
    ///   holders have dropped theirs. The forward-port script is not run again.
    /// * If a mapping for `source` exists with different parameters, it is removed first.
    /// * Otherwise the forward-port script is run and the new mapping recorded.
    ///
    /// # Errors
    /// Propagates failures from removing the old forward or installing the new one.
    async fn add_forward(
        &mut self,
        source: SocketAddrV4,
        target: SocketAddrV4,
        target_prefix: u8,
        src_filter: Option<IpNet>,
    ) -> Result<Arc<()>, Error> {
        if let Some(existing) = self.mappings.get_mut(&source) {
            // `target_prefix` is part of the installed rule (`dprefix`), so a changed
            // prefix must reinstall the forward rather than reuse the stale one.
            let unchanged = existing.target == target
                && existing.target_prefix == target_prefix
                && existing.src_filter == src_filter;
            if unchanged {
                if let Some(existing_rc) = existing.rc.upgrade() {
                    return Ok(existing_rc);
                }
                // Rules are still installed but nobody holds the token: hand out a new one.
                let rc = Arc::new(());
                existing.rc = Arc::downgrade(&rc);
                return Ok(rc);
            }
            // Different target, prefix or src_filter: remove the old forward, then add the new one.
            if let Some(mapping) = self.mappings.remove(&source) {
                unforward(
                    mapping.source,
                    mapping.target,
                    mapping.target_prefix,
                    mapping.src_filter.as_ref(),
                )
                .await?;
            }
        }
        let rc = Arc::new(());
        forward(source, target, target_prefix, src_filter.as_ref()).await?;
        self.mappings.insert(
            source,
            ForwardMapping {
                source,
                target,
                target_prefix,
                src_filter,
                rc: Arc::downgrade(&rc),
            },
        );
        Ok(rc)
    }

    /// Removes every forward whose token has no remaining strong references.
    ///
    /// # Errors
    /// Stops at, and returns, the first failure from removing a forward.
    async fn gc(&mut self) -> Result<(), Error> {
        let to_remove: Vec<SocketAddrV4> = self
            .mappings
            .iter()
            .filter(|(_, mapping)| mapping.rc.strong_count() == 0)
            .map(|(source, _)| *source)
            .collect();
        for source in to_remove {
            if let Some(mapping) = self.mappings.remove(&source) {
                unforward(
                    mapping.source,
                    mapping.target,
                    mapping.target_prefix,
                    mapping.src_filter.as_ref(),
                )
                .await?;
            }
        }
        Ok(())
    }

    /// Returns the `source -> target` pairs of all forwards that still have a live token.
    fn dump(&self) -> BTreeMap<SocketAddrV4, SocketAddrV4> {
        self.mappings
            .iter()
            .filter(|(_, mapping)| mapping.rc.strong_count() > 0)
            .map(|(source, mapping)| (*source, mapping.target))
            .collect()
    }
}
impl Drop for PortForwardState {
    /// Removes all remaining forwards in a spawned task, logging failures instead of propagating.
    fn drop(&mut self) {
        if self.mappings.is_empty() {
            return;
        }
        // Move the mappings out so the cleanup task owns them after `self` is gone.
        let remaining = std::mem::take(&mut self.mappings);
        tokio::spawn(async move {
            for mapping in remaining.into_values() {
                unforward(
                    mapping.source,
                    mapping.target,
                    mapping.target_prefix,
                    mapping.src_filter.as_ref(),
                )
                .await
                .log_err();
            }
        });
    }
}
/// Requests handled by the `PortForwardController` background task.
///
/// Every variant carries a oneshot sender on which the task sends its reply.
enum PortForwardCommand {
/// Install (or reuse) a forward; the reply is the token that keeps it alive.
AddForward {
source: SocketAddrV4,
target: SocketAddrV4,
target_prefix: u8,
src_filter: Option<IpNet>,
respond: oneshot::Sender<Result<Arc<()>, Error>>,
},
/// Remove forwards whose tokens have all been dropped.
Gc {
respond: oneshot::Sender<Result<(), Error>>,
},
/// Reply with the `source -> target` pairs of live forwards.
Dump {
respond: oneshot::Sender<BTreeMap<SocketAddrV4, SocketAddrV4>>,
},
}
/// Handle to a background task that owns the installed port forwards.
pub struct PortForwardController {
/// Command channel into the background task.
req: mpsc::UnboundedSender<PortForwardCommand>,
/// Join handle of the background task.
// NOTE(review): presumably stops the task when dropped, per the type name — confirm.
_thread: NonDetachingJoinHandle<()>,
}
/// Adds or removes an iptables rule only when that changes anything.
///
/// The rule is first probed with `iptables -C`. It is appended (`-A`) only when absent
/// and deleted (`-D`) only when present. A failing probe counts as "rule absent".
///
/// * `nat` — operate on the `nat` table instead of the default table.
/// * `undo` — `true` removes the rule, `false` adds it.
/// * `args` — chain name followed by the rule specification.
///
/// # Errors
/// Propagates a failure of the `-A` / `-D` invocation.
pub async fn add_iptables_rule(nat: bool, undo: bool, args: &[&str]) -> Result<(), Error> {
    // Shared prefix of both invocations: `iptables [-t nat]`.
    let base_command = move || {
        let mut cmd = Command::new("iptables");
        if nat {
            cmd.arg("-t").arg("nat");
        }
        cmd
    };

    let mut probe = base_command();
    let exists = probe
        .arg("-C")
        .args(args)
        .invoke(ErrorKind::Network)
        .await
        .is_ok();

    // Act only when removing an existing rule or adding a missing one.
    if undo == exists {
        let action = if undo { "-D" } else { "-A" };
        let mut apply = base_command();
        apply
            .arg(action)
            .args(args)
            .invoke(ErrorKind::Network)
            .await?;
    }
    Ok(())
}
impl PortForwardController {
pub fn new() -> Self {
let (req_send, mut req_recv) = mpsc::unbounded_channel::<PortForwardCommand>();
let thread = NonDetachingJoinHandle::from(tokio::spawn(async move {
while let Err(e) = async {
Command::new("iptables")
.arg("-P")
.arg("FORWARD")
.arg("DROP")
.invoke(ErrorKind::Network)
.await?;
add_iptables_rule(
false,
false,
&[
"FORWARD",
"-m",
"state",
"--state",
"ESTABLISHED,RELATED",
"-j",
"ACCEPT",
],
)
.await?;
Command::new("sysctl")
.arg("-w")
.arg("net.ipv4.ip_forward=1")
.invoke(ErrorKind::Network)
.await?;
Ok::<_, Error>(())
}
.await
{
tracing::error!(
"{}",
t!(
"net.forward.error-initializing-controller",
error = format!("{e:#}")
)
);
tracing::debug!("{e:?}");
tokio::time::sleep(Duration::from_secs(5)).await;
}
let mut state = PortForwardState::default();
while let Some(cmd) = req_recv.recv().await {
match cmd {
PortForwardCommand::AddForward {
source,
target,
target_prefix,
src_filter,
respond,
} => {
let result = state
.add_forward(source, target, target_prefix, src_filter)
.await;
respond.send(result).ok();
}
PortForwardCommand::Gc { respond } => {
let result = state.gc().await;
respond.send(result).ok();
}
PortForwardCommand::Dump { respond } => {
respond.send(state.dump()).ok();
}
}
}
}));
Self {
req: req_send,
_thread: thread,
}
}
pub async fn add_forward(
&self,
source: SocketAddrV4,
target: SocketAddrV4,
target_prefix: u8,
src_filter: Option<IpNet>,
) -> Result<Arc<()>, Error> {
let (send, recv) = oneshot::channel();
self.req
.send(PortForwardCommand::AddForward {
source,
target,
target_prefix,
src_filter,
respond: send,
})
.map_err(err_has_exited)?;
recv.await.map_err(err_has_exited)?
}
pub async fn gc(&self) -> Result<(), Error> {
let (send, recv) = oneshot::channel();
self.req
.send(PortForwardCommand::Gc { respond: send })
.map_err(err_has_exited)?;
recv.await.map_err(err_has_exited)?
}
pub async fn dump(&self) -> Result<BTreeMap<SocketAddrV4, SocketAddrV4>, Error> {
let (send, recv) = oneshot::channel();
self.req
.send(PortForwardCommand::Dump { respond: send })
.map_err(err_has_exited)?;
recv.await.map_err(err_has_exited)
}
}
/// A request to forward an external port to a target under given requirements.
struct InterfaceForwardRequest {
/// External (listening) port; must match the entry the request is applied to.
external: u16,
/// Address and port traffic is forwarded to.
target: SocketAddrV4,
/// Prefix length passed through to the port-forward controller.
target_prefix: u8,
/// Conditions selecting the local addresses on which the forward is installed.
reqs: ForwardRequirements,
/// Freshly created token; replaced by the existing one if the target is already registered.
rc: Arc<()>,
}
/// Per-external-port state: the registered targets and the forwards currently held for them.
#[derive(Clone)]
struct InterfaceForwardEntry {
/// External port this entry manages; also its key in the `IdOrdMap`.
external: u16,
/// Registered targets by requirements: (target address, target prefix, weak keep-alive token).
targets: BTreeMap<ForwardRequirements, (SocketAddrV4, u8, Weak<()>)>,
// Maps source SocketAddr -> strong reference for the forward created in PortForwardController
forwards: BTreeMap<SocketAddrV4, Arc<()>>,
}
// Keys each entry by its external port so an `IdOrdMap` can index entries by port.
impl IdOrdItem for InterfaceForwardEntry {
type Key<'a> = u16;
fn key(&self) -> Self::Key<'_> {
self.external
}
// NOTE(review): presumably generates the trait's key-lifetime upcast method — confirm against iddqd docs.
iddqd::id_upcast!();
}
impl InterfaceForwardEntry {
    /// Creates an entry for `external` with no targets and no active forwards.
    fn new(external: u16) -> Self {
        Self {
            external,
            targets: BTreeMap::new(),
            forwards: BTreeMap::new(),
        }
    }

    /// Reconciles the held forwards with the current interface state.
    ///
    /// For every IPv4 address of every gateway, the first registered target (in map
    /// order) that is still alive and whose requirements admit that address gets a
    /// forward from `address:external`. Forwards for addresses that no longer
    /// qualify are dropped afterwards.
    ///
    /// # Errors
    /// Returns the first error from the port-forward controller; in that case the
    /// stale forwards are not pruned.
    async fn update(
        &mut self,
        ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
        port_forward: &PortForwardController,
    ) -> Result<(), Error> {
        let mut keep = BTreeSet::<SocketAddrV4>::new();
        for (gw_id, info) in ip_info.iter() {
            let Some(gateway_ips) = &info.ip_info else {
                continue;
            };
            for subnet in gateway_ips.subnets.iter() {
                // Only IPv4 addresses can be forwarded.
                let IpAddr::V4(ip) = subnet.addr() else {
                    continue;
                };
                let addr = SocketAddrV4::new(ip, self.external);
                if keep.contains(&addr) {
                    continue;
                }
                // Pick the first live target whose requirements admit this address.
                let chosen = self
                    .targets
                    .iter()
                    .find_map(|(reqs, (target, target_prefix, rc))| {
                        if rc.strong_count() == 0 {
                            return None;
                        }
                        if !reqs.secure && !info.secure() {
                            return None;
                        }
                        let src_filter = if reqs.public_gateways.contains(gw_id) {
                            // Public on this gateway: accept any source.
                            None
                        } else if reqs.private_ips.contains(&IpAddr::V4(ip)) {
                            // Private on this address: restrict sources to its subnet.
                            Some(subnet.trunc())
                        } else {
                            return None;
                        };
                        Some((*target, *target_prefix, src_filter))
                    });
                let Some((target, target_prefix, src_filter)) = chosen else {
                    continue;
                };
                keep.insert(addr);
                let fwd_rc = port_forward
                    .add_forward(addr, target, target_prefix, src_filter)
                    .await?;
                self.forwards.insert(addr, fwd_rc);
            }
        }
        // Remove forwards that should no longer exist (drops the strong references)
        self.forwards.retain(|addr, _| keep.contains(addr));
        Ok(())
    }

    /// Registers (or refreshes) the target for the request's requirements, then reconciles.
    ///
    /// Returns the token keeping the target registered. This is the existing token when
    /// the same target is already registered and alive, otherwise the request's own.
    ///
    /// # Errors
    /// Returns `ErrorKind::InvalidRequest` if the request's external port differs from
    /// this entry's, or any error from `update`.
    async fn update_request(
        &mut self,
        request: InterfaceForwardRequest,
        ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
        port_forward: &PortForwardController,
    ) -> Result<Arc<()>, Error> {
        let InterfaceForwardRequest {
            external,
            target,
            target_prefix,
            reqs,
            mut rc,
        } = request;
        if external != self.external {
            return Err(Error::new(
                eyre!("{}", t!("net.forward.mismatched-external-port")),
                ErrorKind::InvalidRequest,
            ));
        }
        let slot = self
            .targets
            .entry(reqs)
            .or_insert_with(|| (target, target_prefix, Arc::downgrade(&rc)));
        if slot.0 != target {
            // Same requirements, new target: replace it and tie it to the new token.
            *slot = (target, target_prefix, Arc::downgrade(&rc));
        }
        match slot.2.upgrade() {
            Some(existing) => rc = existing,
            None => slot.2 = Arc::downgrade(&rc),
        }
        self.update(ip_info, port_forward).await?;
        Ok(rc)
    }

    /// Drops targets whose tokens are gone, then reconciles the held forwards.
    async fn gc(
        &mut self,
        ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
        port_forward: &PortForwardController,
    ) -> Result<(), Error> {
        self.targets.retain(|_, (_, _, rc)| rc.strong_count() > 0);
        self.update(ip_info, port_forward).await
    }
}
/// State owned by the `InterfacePortForwardController` background task.
struct InterfaceForwardState {
/// Controller that installs and removes the actual forwards.
port_forward: PortForwardController,
/// One entry per external port.
state: IdOrdMap<InterfaceForwardEntry>,
}
impl InterfaceForwardState {
    /// Creates an empty state that installs forwards through `port_forward`.
    fn new(port_forward: PortForwardController) -> Self {
        let state = IdOrdMap::new();
        Self {
            port_forward,
            state,
        }
    }
}
impl InterfaceForwardState {
    /// Applies `request` to the entry for its external port, creating the entry if needed.
    ///
    /// Returns the token that keeps the requested target registered.
    async fn handle_request(
        &mut self,
        request: InterfaceForwardRequest,
        ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
    ) -> Result<Arc<()>, Error> {
        let external = request.external;
        self.state
            .entry(external)
            .or_insert_with(|| InterfaceForwardEntry::new(external))
            .update_request(request, ip_info, &self.port_forward)
            .await
    }

    /// Garbage-collects every entry against `ip_info`, then asks the port-forward
    /// controller to drop forwards that are no longer referenced.
    ///
    /// # Errors
    /// Stops at the first failing entry; the controller-level gc is skipped in that case.
    async fn sync(
        &mut self,
        ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
    ) -> Result<(), Error> {
        let Self {
            port_forward,
            state,
        } = self;
        for mut entry in state.iter_mut() {
            entry.gc(ip_info, port_forward).await?;
        }
        port_forward.gc().await
    }
}
fn err_has_exited<T>(_: T) -> Error {
Error::new(
eyre!("{}", t!("net.forward.controller-thread-exited")),
ErrorKind::Unknown,
)
}
/// Serializable snapshot of the forward table, keyed by external port.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ForwardTable(pub BTreeMap<u16, ForwardTarget>);
/// One row of a `ForwardTable`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ForwardTarget {
/// Address and port traffic is forwarded to.
pub target: SocketAddrV4,
/// Prefix length registered alongside the target.
pub target_prefix: u8,
/// `Display` rendering of the target's `ForwardRequirements`.
pub reqs: String,
}
impl From<&InterfaceForwardState> for ForwardTable {
fn from(value: &InterfaceForwardState) -> Self {
Self(
value
.state
.iter()
.flat_map(|entry| {
entry
.targets
.iter()
.filter(|(_, (_, _, rc))| rc.strong_count() > 0)
.map(|(reqs, (target, target_prefix, _))| {
(
entry.external,
ForwardTarget {
target: *target,
target_prefix: *target_prefix,
reqs: format!("{reqs}"),
},
)
})
})
.collect(),
)
}
}
/// Requests handled by the `InterfacePortForwardController` background task.
///
/// Every variant carries a oneshot sender on which the task sends its reply.
enum InterfaceForwardCommand {
/// Register a forward request; the reply is the token keeping it registered.
Forward(
InterfaceForwardRequest,
oneshot::Sender<Result<Arc<()>, Error>>,
),
/// Garbage-collect and reconcile all entries against the current interfaces.
Sync(oneshot::Sender<Result<(), Error>>),
/// Reply with a snapshot of the forward table.
DumpTable(oneshot::Sender<ForwardTable>),
}
/// Handle to a background task that keeps port forwards in sync with the network interfaces.
pub struct InterfacePortForwardController {
/// Command channel into the background task.
req: mpsc::UnboundedSender<InterfaceForwardCommand>,
/// Join handle of the background task.
// NOTE(review): presumably stops the task when dropped, per the type name — confirm.
_thread: NonDetachingJoinHandle<()>,
}
impl InterfacePortForwardController {
    /// Spawns the background task and returns a handle to it.
    ///
    /// The task serves commands and, whenever `ip_info` changes, re-reads the interface
    /// map and re-syncs all forwards (logging, not propagating, sync errors). It exits
    /// once every command sender has been dropped.
    pub fn new(mut ip_info: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>) -> Self {
        let port_forward = PortForwardController::new();
        let (req_send, mut req_recv) = mpsc::unbounded_channel::<InterfaceForwardCommand>();
        let worker = tokio::spawn(async move {
            let mut state = InterfaceForwardState::new(port_forward);
            let mut interfaces = ip_info.read_and_mark_seen();
            loop {
                tokio::select! {
                    msg = req_recv.recv() => {
                        let Some(cmd) = msg else {
                            break;
                        };
                        // A failed `send` means the requester gave up waiting; ignore it.
                        match cmd {
                            InterfaceForwardCommand::Forward(req, reply) => {
                                let outcome = state.handle_request(req, &interfaces).await;
                                reply.send(outcome).ok();
                            }
                            InterfaceForwardCommand::Sync(reply) => {
                                let outcome = state.sync(&interfaces).await;
                                reply.send(outcome).ok();
                            }
                            InterfaceForwardCommand::DumpTable(reply) => {
                                reply.send(ForwardTable::from(&state)).ok();
                            }
                        }
                    }
                    _ = ip_info.changed() => {
                        interfaces = ip_info.read();
                        state.sync(&interfaces).await.log_err();
                    }
                }
            }
        });
        Self {
            req: req_send,
            _thread: NonDetachingJoinHandle::from(worker),
        }
    }

    /// Registers a forward of `external` to `target` under `reqs`.
    ///
    /// Returns the token keeping the registration alive.
    ///
    /// # Errors
    /// Returns the task's error, or a "controller thread exited" error if the task is gone.
    pub async fn add(
        &self,
        external: u16,
        reqs: ForwardRequirements,
        target: SocketAddrV4,
        target_prefix: u8,
    ) -> Result<Arc<()>, Error> {
        let request = InterfaceForwardRequest {
            external,
            target,
            target_prefix,
            reqs,
            rc: Arc::new(()),
        };
        let (reply_tx, reply_rx) = oneshot::channel();
        self.req
            .send(InterfaceForwardCommand::Forward(request, reply_tx))
            .map_err(err_has_exited)?;
        reply_rx.await.map_err(err_has_exited)?
    }

    /// Asks the task to garbage-collect and re-sync all forwards.
    ///
    /// # Errors
    /// Returns the task's error, or a "controller thread exited" error if the task is gone.
    pub async fn gc(&self) -> Result<(), Error> {
        let (reply_tx, reply_rx) = oneshot::channel();
        self.req
            .send(InterfaceForwardCommand::Sync(reply_tx))
            .map_err(err_has_exited)?;
        reply_rx.await.map_err(err_has_exited)?
    }

    /// Returns a snapshot of the current forward table.
    ///
    /// # Errors
    /// Returns a "controller thread exited" error if the task is gone.
    pub async fn dump_table(&self) -> Result<ForwardTable, Error> {
        let (reply_tx, reply_rx) = oneshot::channel();
        self.req
            .send(InterfaceForwardCommand::DumpTable(reply_tx))
            .map_err(err_has_exited)?;
        reply_rx.await.map_err(err_has_exited)
    }
}
/// Installs a forward by running the `forward-port` script.
///
/// Parameters are passed as environment variables: `sip`/`sport` (source address and
/// port), `dip`/`dport` (target address and port), `dprefix`, `bridge_subnet` (the /24
/// network containing `HOST_IP`) and, when a filter is given, `src_subnet`.
///
/// # Errors
/// Fails if the bridge subnet cannot be constructed or the script invocation fails.
async fn forward(
    source: SocketAddrV4,
    target: SocketAddrV4,
    target_prefix: u8,
    src_filter: Option<&IpNet>,
) -> Result<(), Error> {
    let bridge_subnet = Ipv4Net::new(HOST_IP.into(), 24)
        .with_kind(ErrorKind::ParseNetAddress)?
        .trunc();

    let mut cmd = Command::new("/usr/lib/startos/scripts/forward-port");
    let vars = [
        ("sip", source.ip().to_string()),
        ("dip", target.ip().to_string()),
        ("dprefix", target_prefix.to_string()),
        ("sport", source.port().to_string()),
        ("dport", target.port().to_string()),
        ("bridge_subnet", bridge_subnet.to_string()),
    ];
    for (key, value) in vars {
        cmd.env(key, value);
    }
    if let Some(subnet) = src_filter {
        cmd.env("src_subnet", subnet.to_string());
    }
    cmd.invoke(ErrorKind::Network).await?;
    Ok(())
}
/// Removes a forward by running the `forward-port` script with `UNDO=1`.
///
/// Parameters are passed as environment variables: `sip`/`sport` (source address and
/// port), `dip`/`dport` (target address and port), `dprefix` and, when a filter is
/// given, `src_subnet`.
///
/// # Errors
/// Fails if the script invocation fails.
async fn unforward(
    source: SocketAddrV4,
    target: SocketAddrV4,
    target_prefix: u8,
    src_filter: Option<&IpNet>,
) -> Result<(), Error> {
    let mut cmd = Command::new("/usr/lib/startos/scripts/forward-port");
    let vars = [
        ("UNDO", String::from("1")),
        ("sip", source.ip().to_string()),
        ("dip", target.ip().to_string()),
        ("dprefix", target_prefix.to_string()),
        ("sport", source.port().to_string()),
        ("dport", target.port().to_string()),
    ];
    for (key, value) in vars {
        cmd.env(key, value);
    }
    if let Some(subnet) = src_filter {
        cmd.env("src_subnet", subnet.to_string());
    }
    cmd.invoke(ErrorKind::Network).await?;
    Ok(())
}