This commit is contained in:
Aiden McClelland
2025-10-20 18:05:57 -06:00
parent 716bf920f5
commit 40b00bae75
26 changed files with 736 additions and 401 deletions

View File

@@ -85,16 +85,16 @@ async-acme = { version = "0.6.0", git = "https://github.com/dr-bonez/async-acme.
"use_rustls",
"use_tokio",
] }
async-compression = { version = "0.4.4", features = [
async-compression = { version = "0.4.32", features = [
"gzip",
"brotli",
"zstd",
"tokio",
] }
async-stream = "0.3.5"
async-trait = "0.1.74"
axum = { version = "0.8.4", features = ["ws"] }
barrage = "0.2.3"
backhand = "0.21.0"
backtrace-on-stack-overflow = { version = "0.3.0", optional = true }
base32 = "0.5.0"
base64 = "0.22.1"
@@ -153,6 +153,7 @@ id-pool = { version = "0.2.2", default-features = false, features = [
"serde",
"u16",
] }
iddqd = "0.3.14"
imbl = { version = "6", features = ["serde", "small-chunks"] }
imbl-value = { version = "0.4.3", features = ["ts-rs"] }
include_dir = { version = "0.7.3", features = ["metadata"] }
@@ -283,6 +284,7 @@ unix-named-pipe = "0.2.0"
url = { version = "2.4.1", features = ["serde"] }
urlencoding = "2.1.3"
uuid = { version = "1.4.1", features = ["v4"] }
visit-rs = "0.1.1"
x25519-dalek = { version = "2.0.1", features = ["static_secrets"] }
zbus = "5.1.1"
zeroize = "1.6.0"

View File

@@ -250,30 +250,6 @@ impl NetworkInterfaceInfo {
}
(&*LO, &*LOOPBACK)
}
pub fn lxc_bridge() -> (&'static GatewayId, &'static Self) {
lazy_static! {
static ref LXCBR0: GatewayId =
GatewayId::from(InternedString::intern(START9_BRIDGE_IFACE));
static ref LXC_BRIDGE: NetworkInterfaceInfo = NetworkInterfaceInfo {
name: Some(InternedString::from_static("LXC Bridge Interface")),
public: Some(false),
secure: Some(true),
ip_info: Some(IpInfo {
name: START9_BRIDGE_IFACE.into(),
scope_id: 0,
device_type: None,
subnets: [IpNet::new(HOST_IP.into(), 24).unwrap()]
.into_iter()
.collect(),
lan_ip: [IpAddr::from(HOST_IP)].into_iter().collect(),
wan_ip: None,
ntp_servers: Default::default(),
dns_servers: Default::default(),
}),
};
}
(&*LXCBR0, &*LXC_BRIDGE)
}
pub fn public(&self) -> bool {
self.public.unwrap_or_else(|| {
!self.ip_info.as_ref().map_or(true, |ip_info| {
@@ -339,7 +315,9 @@ pub struct IpInfo {
pub enum NetworkInterfaceType {
Ethernet,
Wireless,
Bridge,
Wireguard,
Loopback,
}
#[derive(Debug, Deserialize, Serialize, HasModel, TS)]

View File

@@ -415,10 +415,7 @@ impl Resolver {
{
if let Some(res) = self.net_iface.peek(|i| {
i.values()
.chain([
NetworkInterfaceInfo::loopback().1,
NetworkInterfaceInfo::lxc_bridge().1,
])
.chain([NetworkInterfaceInfo::loopback().1])
.filter_map(|i| i.ip_info.as_ref())
.find(|i| i.subnets.iter().any(|s| s.contains(&src)))
.map(|ip_info| {

View File

@@ -5,6 +5,7 @@ use std::sync::{Arc, Weak};
use futures::channel::oneshot;
use helpers::NonDetachingJoinHandle;
use id_pool::IdPool;
use iddqd::{IdOrdItem, IdOrdMap};
use imbl::OrdMap;
use models::GatewayId;
use rpc_toolkit::{from_fn_async, Context, HandlerArgs, HandlerExt, ParentHandler};
@@ -14,7 +15,7 @@ use tokio::sync::mpsc;
use crate::context::{CliContext, RpcContext};
use crate::db::model::public::NetworkInterfaceInfo;
use crate::net::gateway::{DynInterfaceFilter, InterfaceFilter, SecureFilter};
use crate::net::gateway::{DynInterfaceFilter, InterfaceFilter};
use crate::net::utils::ipv6_is_link_local;
use crate::prelude::*;
use crate::util::serde::{display_serializable, HandlerExtSerde};
@@ -60,17 +61,10 @@ pub fn forward_api<C: Context>() -> ParentHandler<C> {
}
let mut table = Table::new();
table.add_row(row![bc => "FROM", "TO", "FILTER / GATEWAY"]);
table.add_row(row![bc => "FROM", "TO", "FILTER"]);
for (external, target) in res.0 {
table.add_row(row![external, target.target, target.filter]);
for (source, gateway) in target.gateways {
table.add_row(row![
format!("{}:{}", source, external),
target.target,
gateway
]);
}
}
table.print_tty(false)?;
@@ -85,41 +79,43 @@ struct ForwardRequest {
external: u16,
target: SocketAddr,
filter: DynInterfaceFilter,
rc: Weak<()>,
rc: Arc<()>,
}
#[derive(Clone)]
struct ForwardEntry {
external: u16,
target: SocketAddr,
prev_filter: DynInterfaceFilter,
forwards: BTreeMap<SocketAddr, GatewayId>,
rc: Weak<()>,
filter: BTreeMap<DynInterfaceFilter, (SocketAddr, Weak<()>)>,
forwards: BTreeMap<SocketAddr, (GatewayId, SocketAddr)>,
}
impl IdOrdItem for ForwardEntry {
type Key<'a> = u16;
fn key(&self) -> Self::Key<'_> {
self.external
}
iddqd::id_upcast!();
}
impl ForwardEntry {
fn new(external: u16, target: SocketAddr, rc: Weak<()>) -> Self {
fn new(external: u16) -> Self {
Self {
external,
target,
prev_filter: false.into_dyn(),
filter: BTreeMap::new(),
forwards: BTreeMap::new(),
rc,
}
}
fn take(&mut self) -> Self {
Self {
external: self.external,
target: self.target,
prev_filter: std::mem::replace(&mut self.prev_filter, false.into_dyn()),
filter: std::mem::take(&mut self.filter),
forwards: std::mem::take(&mut self.forwards),
rc: self.rc.clone(),
}
}
async fn destroy(mut self) -> Result<(), Error> {
while let Some((source, interface)) = self.forwards.pop_first() {
unforward(interface.as_str(), source, self.target).await?;
while let Some((source, (interface, target))) = self.forwards.pop_first() {
unforward(interface.as_str(), source, target).await?;
}
Ok(())
}
@@ -127,38 +123,37 @@ impl ForwardEntry {
async fn update(
&mut self,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
filter: Option<DynInterfaceFilter>,
) -> Result<(), Error> {
if self.rc.strong_count() == 0 {
return self.take().destroy().await;
}
let filter_ref = filter.as_ref().unwrap_or(&self.prev_filter);
let mut keep = BTreeSet::<SocketAddr>::new();
for (iface, info) in ip_info
.iter()
// .chain([NetworkInterfaceInfo::loopback()])
.filter(|(id, info)| filter_ref.filter(*id, *info))
{
if let Some(ip_info) = &info.ip_info {
for ipnet in &ip_info.subnets {
let addr = match ipnet.addr() {
IpAddr::V6(ip6) => SocketAddrV6::new(
ip6,
self.external,
0,
if ipv6_is_link_local(ip6) {
ip_info.scope_id
} else {
0
},
)
.into(),
ip => SocketAddr::new(ip, self.external),
};
keep.insert(addr);
if !self.forwards.contains_key(&addr) {
forward(iface.as_str(), addr, self.target).await?;
self.forwards.insert(addr, iface.clone());
for (iface, info) in ip_info.iter() {
if let Some(target) = self
.filter
.iter()
.filter(|(_, (_, rc))| rc.strong_count() > 0)
.find(|(filter, _)| filter.filter(iface, info))
.map(|(_, (target, _))| *target)
{
if let Some(ip_info) = &info.ip_info {
for ipnet in &ip_info.subnets {
let addr = match ipnet.addr() {
IpAddr::V6(ip6) => SocketAddrV6::new(
ip6,
self.external,
0,
if ipv6_is_link_local(ip6) {
ip_info.scope_id
} else {
0
},
)
.into(),
ip => SocketAddr::new(ip, self.external),
};
keep.insert(addr);
if !self.forwards.contains_key(&addr) {
forward(iface.as_str(), addr, target).await?;
self.forwards.insert(addr, (iface.clone(), target));
}
}
}
}
@@ -170,13 +165,10 @@ impl ForwardEntry {
.filter(|a| !keep.contains(a))
.collect::<Vec<_>>();
for rm in rm {
if let Some((source, interface)) = self.forwards.remove_entry(&rm) {
unforward(interface.as_str(), source, self.target).await?;
if let Some((source, (interface, target))) = self.forwards.remove_entry(&rm) {
unforward(interface.as_str(), source, target).await?;
}
}
if let Some(filter) = filter {
self.prev_filter = filter;
}
Ok(())
}
@@ -186,20 +178,34 @@ impl ForwardEntry {
external,
target,
filter,
rc,
mut rc,
}: ForwardRequest,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
) -> Result<(), Error> {
if external != self.external || target != self.target {
self.take().destroy().await?;
*self = Self::new(external, target, rc);
self.update(ip_info, Some(filter)).await?;
} else {
self.rc = rc;
self.update(ip_info, Some(filter).filter(|f| f != &self.prev_filter))
.await?;
) -> Result<Arc<()>, Error> {
if external != self.external {
return Err(Error::new(
eyre!("Mismatched external port in ForwardEntry"),
ErrorKind::InvalidRequest,
));
}
Ok(())
let entry = self
.filter
.entry(filter)
.or_insert_with(|| (target, Arc::downgrade(&rc)));
if entry.0 != target {
entry.0 = target;
entry.1 = Arc::downgrade(&rc);
}
if let Some(existing) = entry.1.upgrade() {
rc = existing;
} else {
entry.1 = Arc::downgrade(&rc);
}
self.update(ip_info).await?;
Ok(rc)
}
}
impl Drop for ForwardEntry {
@@ -215,17 +221,17 @@ impl Drop for ForwardEntry {
#[derive(Default, Clone)]
struct ForwardState {
state: BTreeMap<u16, ForwardEntry>,
state: IdOrdMap<ForwardEntry>,
}
impl ForwardState {
async fn handle_request(
&mut self,
request: ForwardRequest,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
) -> Result<(), Error> {
) -> Result<Arc<()>, Error> {
self.state
.entry(request.external)
.or_insert_with(|| ForwardEntry::new(request.external, request.target, Weak::new()))
.or_insert_with(|| ForwardEntry::new(request.external))
.update_request(request, ip_info)
.await
}
@@ -233,10 +239,9 @@ impl ForwardState {
&mut self,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
) -> Result<(), Error> {
for entry in self.state.values_mut() {
entry.update(ip_info, None).await?;
for mut entry in self.state.iter_mut() {
entry.update(ip_info).await?;
}
self.state.retain(|_, fwd| fwd.rc.strong_count() > 0);
Ok(())
}
}
@@ -254,7 +259,6 @@ pub struct ForwardTable(pub BTreeMap<u16, ForwardTarget>);
pub struct ForwardTarget {
pub target: SocketAddr,
pub filter: String,
pub gateways: BTreeMap<SocketAddr, GatewayId>,
}
impl From<&ForwardState> for ForwardTable {
fn from(value: &ForwardState) -> Self {
@@ -262,15 +266,20 @@ impl From<&ForwardState> for ForwardTable {
value
.state
.iter()
.map(|(external, entry)| {
(
*external,
ForwardTarget {
target: entry.target,
filter: format!("{:?}", entry.prev_filter),
gateways: entry.forwards.clone(),
},
)
.flat_map(|entry| {
entry
.filter
.iter()
.filter(|(_, (_, rc))| rc.strong_count() > 0)
.map(|(filter, (target, _))| {
(
entry.external,
ForwardTarget {
target: *target,
filter: format!("{:?}", filter),
},
)
})
})
.collect(),
)
@@ -278,13 +287,15 @@ impl From<&ForwardState> for ForwardTable {
}
enum ForwardCommand {
Forward(ForwardRequest, oneshot::Sender<Result<(), Error>>),
Forward(ForwardRequest, oneshot::Sender<Result<Arc<()>, Error>>),
Sync(oneshot::Sender<Result<(), Error>>),
DumpTable(oneshot::Sender<ForwardTable>),
}
#[test]
fn test() {
use crate::net::gateway::SecureFilter;
assert_ne!(
false.into_dyn(),
SecureFilter { secure: false }.into_dyn().into_dyn()
@@ -340,13 +351,13 @@ impl PortForwardController {
external,
target,
filter,
rc: Arc::downgrade(&rc),
rc,
},
send,
))
.map_err(err_has_exited)?;
recv.await.map_err(err_has_exited)?.map(|_| rc)
recv.await.map_err(err_has_exited)?
}
pub async fn gc(&self) -> Result<(), Error> {
let (send, recv) = oneshot::channel();

View File

@@ -585,18 +585,19 @@ async fn watch_ip(
loop {
until
.run(async {
let external = active_connection_proxy.state_flags().await? & 0x80 != 0;
if external {
return Ok(());
}
let device_type = match device_proxy.device_type().await? {
1 => Some(NetworkInterfaceType::Ethernet),
2 => Some(NetworkInterfaceType::Wireless),
13 => Some(NetworkInterfaceType::Bridge),
29 => Some(NetworkInterfaceType::Wireguard),
32 => Some(NetworkInterfaceType::Loopback),
_ => None,
};
if device_type == Some(NetworkInterfaceType::Loopback) {
return Ok(());
}
let name = InternedString::from(active_connection_proxy.id().await?);
let dhcp4_config = active_connection_proxy.dhcp4_config().await?;
@@ -787,13 +788,7 @@ impl NetworkInterfaceWatcher {
watch_activated: impl IntoIterator<Item = GatewayId>,
) -> Self {
let ip_info = Watch::new(OrdMap::new());
let activated = Watch::new(
watch_activated
.into_iter()
.chain([NetworkInterfaceInfo::lxc_bridge().0.clone()])
.map(|k| (k, false))
.collect(),
);
let activated = Watch::new(watch_activated.into_iter().map(|k| (k, false)).collect());
Self {
activated: activated.clone(),
ip_info: ip_info.clone(),
@@ -1384,14 +1379,12 @@ impl ListenerMap {
fn update(
&mut self,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
lxc_bridge: bool,
filter: &impl InterfaceFilter,
) -> Result<(), Error> {
let mut keep = BTreeSet::<SocketAddr>::new();
for (_, info) in ip_info
.iter()
.chain([NetworkInterfaceInfo::loopback()])
.chain(Some(NetworkInterfaceInfo::lxc_bridge()).filter(|_| lxc_bridge))
.filter(|(id, info)| filter.filter(*id, *info))
{
if let Some(ip_info) = &info.ip_info {
@@ -1466,10 +1459,7 @@ pub fn lookup_info_by_addr(
) -> Option<(&GatewayId, &NetworkInterfaceInfo)> {
ip_info
.iter()
.chain([
NetworkInterfaceInfo::loopback(),
NetworkInterfaceInfo::lxc_bridge(),
])
.chain([NetworkInterfaceInfo::loopback()])
.find(|(_, i)| {
i.ip_info
.as_ref()
@@ -1495,16 +1485,10 @@ impl NetworkInterfaceListener {
filter: &impl InterfaceFilter,
) -> Poll<Result<Accepted, Error>> {
while self.ip_info.poll_changed(cx).is_ready()
|| self.activated.poll_changed(cx).is_ready()
|| !DynInterfaceFilterT::eq(&self.listeners.prev_filter, filter.as_any())
{
let lxc_bridge = self.activated.peek(|a| {
a.get(NetworkInterfaceInfo::lxc_bridge().0)
.copied()
.unwrap_or_default()
});
self.ip_info
.peek_and_mark_seen(|ip_info| self.listeners.update(ip_info, lxc_bridge, filter))?;
.peek_and_mark_seen(|ip_info| self.listeners.update(ip_info, filter))?;
}
self.listeners.poll_accept(cx)
}
@@ -1562,9 +1546,12 @@ impl SelfContainedNetworkInterfaceListener {
pub fn bind(port: u16) -> Self {
let ip_info = Watch::new(OrdMap::new());
let activated = Watch::new(
[(NetworkInterfaceInfo::lxc_bridge().0.clone(), false)]
.into_iter()
.collect(),
[(
GatewayId::from(InternedString::from(START9_BRIDGE_IFACE)),
false,
)]
.into_iter()
.collect(),
);
let _watch_thread = tokio::spawn(watcher(ip_info.clone(), activated.clone())).into();
Self {

View File

@@ -6,7 +6,7 @@ use color_eyre::eyre::eyre;
use imbl::{vector, OrdMap};
use imbl_value::InternedString;
use ipnet::IpNet;
use models::{HostId, OptionExt, PackageId};
use models::{GatewayId, HostId, OptionExt, PackageId};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tracing::instrument;
@@ -16,7 +16,7 @@ use crate::db::model::Database;
use crate::error::ErrorCollection;
use crate::hostname::Hostname;
use crate::net::dns::DnsController;
use crate::net::forward::PortForwardController;
use crate::net::forward::{PortForwardController, START9_BRIDGE_IFACE};
use crate::net::gateway::{
AndFilter, DynInterfaceFilter, IdFilter, InterfaceFilter, NetworkInterfaceController, OrFilter,
PublicFilter, SecureFilter,
@@ -283,9 +283,9 @@ impl NetServiceData {
IdFilter(
NetworkInterfaceInfo::loopback().0.clone(),
),
IdFilter(
NetworkInterfaceInfo::lxc_bridge().0.clone(),
),
IdFilter(GatewayId::from(InternedString::from(
START9_BRIDGE_IFACE,
))),
)
.into_dyn(),
acme: None,

View File

@@ -3,12 +3,10 @@ use tokio::io::{AsyncSeek, AsyncWrite};
use crate::prelude::*;
use crate::util::io::TrackingIO;
#[async_trait::async_trait]
pub trait Sink: AsyncWrite + Unpin + Send {
async fn current_position(&mut self) -> Result<u64, Error>;
fn current_position(&mut self) -> impl Future<Output = Result<u64, Error>> + Send + '_;
}
#[async_trait::async_trait]
impl<S: AsyncWrite + AsyncSeek + Unpin + Send> Sink for S {
async fn current_position(&mut self) -> Result<u64, Error> {
use tokio::io::AsyncSeekExt;
@@ -17,7 +15,6 @@ impl<S: AsyncWrite + AsyncSeek + Unpin + Send> Sink for S {
}
}
#[async_trait::async_trait]
impl<W: AsyncWrite + Unpin + Send> Sink for TrackingIO<W> {
async fn current_position(&mut self) -> Result<u64, Error> {
Ok(self.position())

View File

@@ -1,14 +1,17 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use clap::Parser;
use imbl_value::InternedString;
use ipnet::Ipv4Net;
use models::GatewayId;
use rpc_toolkit::{from_fn_async, Context, Empty, HandlerArgs, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use crate::context::CliContext;
use crate::net::gateway::{IdFilter, InterfaceFilter};
use crate::prelude::*;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::db::GatewayPort;
use crate::tunnel::wg::{ClientConfig, WgConfig, WgSubnetClients, WgSubnetConfig};
use crate::util::serde::{display_serializable, HandlerExtSerde};
@@ -27,26 +30,26 @@ pub fn tunnel_api<C: Context>() -> ParentHandler<C> {
"subnet",
subnet_api::<C>().with_about("Add, remove, or modify subnets"),
)
// .subcommand(
// "port-forward",
// ParentHandler::<C>::new()
// .subcommand(
// "add",
// from_fn_async(add_forward)
// .with_metadata("sync_db", Value::Bool(true))
// .no_display()
// .with_about("Add a new port forward")
// .with_call_remote::<CliContext>(),
// )
// .subcommand(
// "remove",
// from_fn_async(remove_forward)
// .with_metadata("sync_db", Value::Bool(true))
// .no_display()
// .with_about("Remove a port forward")
// .with_call_remote::<CliContext>(),
// ),
// )
.subcommand(
"port-forward",
ParentHandler::<C>::new()
.subcommand(
"add",
from_fn_async(add_forward)
.with_metadata("sync_db", Value::Bool(true))
.no_display()
.with_about("Add a new port forward")
.with_call_remote::<CliContext>(),
)
.subcommand(
"remove",
from_fn_async(remove_forward)
.with_metadata("sync_db", Value::Bool(true))
.no_display()
.with_about("Remove a port forward")
.with_call_remote::<CliContext>(),
),
)
}
#[derive(Deserialize, Serialize, Parser)]
@@ -345,3 +348,53 @@ pub async fn show_config(
(wan_addr, wg.as_port().de()?).into(),
))
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct AddPortForwardParams {
source: GatewayPort,
target: SocketAddrV4,
}
pub async fn add_forward(
ctx: TunnelContext,
AddPortForwardParams { source, target }: AddPortForwardParams,
) -> Result<(), Error> {
ctx.db
.mutate(|db| db.as_port_forwards_mut().insert(&source, &target))
.await
.result?;
let rc = ctx
.forward
.add(
source.1,
IdFilter(source.0.clone()).into_dyn(),
target.into(),
)
.await?;
ctx.active_forwards.mutate(|m| {
m.insert(source, rc);
});
Ok(())
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct RemovePortForwardParams {
source: GatewayPort,
}
pub async fn remove_forward(
ctx: TunnelContext,
RemovePortForwardParams { source, .. }: RemovePortForwardParams,
) -> Result<(), Error> {
ctx.db
.mutate(|db| db.as_port_forwards_mut().remove(&source))
.await
.result?;
if let Some(rc) = ctx.active_forwards.mutate(|m| m.remove(&source)) {
drop(rc);
ctx.forward.gc().await?;
}
Ok(())
}

View File

@@ -1,4 +1,4 @@
use std::collections::BTreeSet;
use std::collections::{BTreeMap, BTreeSet};
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::ops::Deref;
use std::path::{Path, PathBuf};
@@ -10,6 +10,7 @@ use helpers::NonDetachingJoinHandle;
use http::HeaderMap;
use imbl::OrdMap;
use imbl_value::InternedString;
use models::GatewayId;
use patch_db::PatchDb;
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{CallRemote, Context, Empty};
@@ -22,12 +23,13 @@ use url::Url;
use crate::auth::Sessions;
use crate::context::config::ContextConfig;
use crate::context::{CliContext, RpcContext};
use crate::db::model::public::NetworkInterfaceType;
use crate::middleware::auth::AuthContext;
use crate::net::forward::PortForwardController;
use crate::net::gateway::NetworkInterfaceWatcher;
use crate::net::gateway::{IdFilter, InterfaceFilter, NetworkInterfaceWatcher};
use crate::prelude::*;
use crate::rpc_continuations::{OpenAuthedContinuations, RpcContinuations};
use crate::tunnel::db::TunnelDatabase;
use crate::tunnel::db::{GatewayPort, TunnelDatabase};
use crate::tunnel::TUNNEL_DEFAULT_PORT;
use crate::util::io::read_file_to_string;
use crate::util::sync::SyncMutex;
@@ -73,6 +75,7 @@ pub struct TunnelContextSeed {
pub ephemeral_sessions: SyncMutex<Sessions>,
pub net_iface: NetworkInterfaceWatcher,
pub forward: PortForwardController,
pub active_forwards: SyncMutex<BTreeMap<GatewayPort, Arc<()>>>,
pub masquerade_thread: NonDetachingJoinHandle<()>,
pub shutdown: Sender<()>,
}
@@ -114,7 +117,17 @@ impl TunnelContext {
let mut masquerade_net_iface = net_iface.subscribe();
let masquerade_thread = tokio::spawn(async move {
loop {
for iface in masquerade_net_iface.peek(|i| i.keys().cloned().collect::<Vec<_>>()) {
for iface in masquerade_net_iface.peek(|i| {
i.iter()
.filter(|(_, info)| {
dbg!(info).ip_info.as_ref().map_or(false, |i| {
dbg!(i).device_type != Some(NetworkInterfaceType::Wireguard)
})
})
.map(|(name, _)| name)
.cloned()
.collect::<Vec<_>>()
}) {
if Command::new("iptables")
.arg("-t")
.arg("nat")
@@ -128,6 +141,7 @@ impl TunnelContext {
.await
.is_err()
{
tracing::info!("Adding masquerade rule for interface {}", iface);
Command::new("iptables")
.arg("-t")
.arg("nat")
@@ -144,11 +158,23 @@ impl TunnelContext {
}
masquerade_net_iface.changed().await;
tracing::info!("Network interfaces changed, updating masquerade rules");
}
})
.into();
db.peek().await.into_wg().de()?.sync().await?;
let peek = db.peek().await;
peek.as_wg().de()?.sync().await?;
let mut active_forwards = BTreeMap::new();
for (from, to) in peek.as_port_forwards().de()?.0 {
active_forwards.insert(
from.clone(),
forward
.add(from.1, IdFilter(from.0).into_dyn(), to.into())
.await?,
);
}
Ok(Self(Arc::new(TunnelContextSeed {
listen,
@@ -164,6 +190,7 @@ impl TunnelContext {
ephemeral_sessions: SyncMutex::new(Sessions::new()),
net_iface,
forward,
active_forwards: SyncMutex::new(active_forwards),
masquerade_thread,
shutdown,
})))

View File

@@ -2,9 +2,12 @@ use std::collections::BTreeMap;
use std::net::SocketAddrV4;
use std::path::PathBuf;
use clap::builder::ValueParserFactory;
use clap::Parser;
use imbl::HashMap;
use imbl_value::InternedString;
use itertools::Itertools;
use models::{FromStrParser, GatewayId};
use patch_db::json_ptr::{JsonPointer, ROOT};
use patch_db::Dump;
use rpc_toolkit::yajrc::RpcError;
@@ -20,7 +23,52 @@ use crate::sign::AnyVerifyingKey;
use crate::tunnel::auth::SignerInfo;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::wg::WgServer;
use crate::util::serde::{apply_expr, HandlerExtSerde};
use crate::util::serde::{apply_expr, deserialize_from_str, serialize_display, HandlerExtSerde};
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct GatewayPort(pub GatewayId, pub u16);
impl std::fmt::Display for GatewayPort {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}:{}", self.0, self.1)
}
}
impl std::str::FromStr for GatewayPort {
type Err = crate::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut parts = s.splitn(2, ':');
let gw: GatewayId = parts
.next()
.ok_or_else(|| Error::new(eyre!("missing gateway id"), ErrorKind::ParseNetAddress))?
.parse()?;
let port: u16 = parts
.next()
.ok_or_else(|| Error::new(eyre!("missing port"), ErrorKind::ParseNetAddress))?
.parse()?;
Ok(GatewayPort(gw, port))
}
}
impl Serialize for GatewayPort {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serialize_display(self, serializer)
}
}
impl<'de> Deserialize<'de> for GatewayPort {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserialize_from_str(deserializer)
}
}
impl ValueParserFactory for GatewayPort {
type Parser = FromStrParser<Self>;
fn value_parser() -> Self::Parser {
FromStrParser::new()
}
}
#[derive(Default, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "camelCase")]
@@ -30,7 +78,20 @@ pub struct TunnelDatabase {
pub password: String,
pub auth_pubkeys: HashMap<AnyVerifyingKey, SignerInfo>,
pub wg: WgServer,
pub port_forwards: BTreeMap<SocketAddrV4, SocketAddrV4>,
pub port_forwards: PortForwards,
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct PortForwards(pub BTreeMap<GatewayPort, SocketAddrV4>);
impl Map for PortForwards {
type Key = GatewayPort;
type Value = SocketAddrV4;
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
Self::key_string(key)
}
fn key_string(key: &Self::Key) -> Result<InternedString, Error> {
Ok(InternedString::from_display(key))
}
}
pub fn db_api<C: Context>() -> ParentHandler<C> {

View File

@@ -1 +0,0 @@
use crate::prelude::*;

View File

@@ -1,13 +1,11 @@
use axum::Router;
use futures::future::ready;
use rpc_toolkit::{from_fn_async, Context, HandlerExt, ParentHandler, Server};
use rpc_toolkit::Server;
use crate::context::CliContext;
use crate::middleware::auth::Auth;
use crate::middleware::cors::Cors;
use crate::net::static_server::{bad_request, not_found, server_error};
use crate::net::web_server::{Accept, WebServer};
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::tunnel::context::TunnelContext;
@@ -15,7 +13,6 @@ pub mod api;
pub mod auth;
pub mod context;
pub mod db;
pub mod forward;
pub mod wg;
pub const TUNNEL_DEFAULT_PORT: u16 = 5960;

View File

@@ -14,7 +14,7 @@ use std::time::Duration;
use bytes::{Buf, BytesMut};
use clap::builder::ValueParserFactory;
use futures::future::{BoxFuture, Fuse};
use futures::{AsyncSeek, FutureExt, Stream, StreamExt, TryStreamExt};
use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
use helpers::{AtomicFile, NonDetachingJoinHandle};
use inotify::{EventMask, EventStream, Inotify, WatchMask};
use models::FromStrParser;
@@ -22,7 +22,8 @@ use nix::unistd::{Gid, Uid};
use serde::{Deserialize, Serialize};
use tokio::fs::{File, OpenOptions};
use tokio::io::{
duplex, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf, WriteHalf,
duplex, AsyncRead, AsyncReadExt, AsyncSeek, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf,
SeekFrom, WriteHalf,
};
use tokio::net::TcpStream;
use tokio::sync::{Notify, OwnedMutexGuard};

View File

@@ -49,6 +49,7 @@ pub mod net;
pub mod rpc;
pub mod rpc_client;
pub mod serde;
pub mod squashfs;
pub mod sync;
#[derive(Clone, Copy, Debug, ::serde::Deserialize, ::serde::Serialize)]

View File

@@ -0,0 +1,268 @@
use std::io::{Seek, Write};
use std::path::Path;
use std::task::Poll;
use async_compression::codecs::{Encode, ZstdEncoder};
use async_compression::core::util::PartialBuffer;
use futures::{ready, TryStreamExt};
use tokio::io::{AsyncSeek, AsyncWrite};
use visit_rs::{Visit, VisitAsync, VisitFields, VisitFieldsAsync, Visitor};
use crate::prelude::*;
use crate::registry::os::asset::add;
struct SquashfsSerializer<W> {
writer: W,
}
impl<W> Visitor for SquashfsSerializer<W> {
type Result = Result<(), Error>;
}
macro_rules! impl_visit_le {
($($ty:ty),*) => {
$(
impl<W: AsyncWrite + Unpin + Send> VisitAsync<SquashfsSerializer<W>> for $ty {
async fn visit_async(&self, visitor: &mut SquashfsSerializer<W>) -> Result<(), Error> {
use tokio::io::AsyncWriteExt;
visitor.writer.write_all(&self.to_le_bytes()).await?;
Ok(())
}
}
impl<W: Write> Visit<SquashfsSerializer<W>> for $ty {
fn visit(&self, visitor: &mut SquashfsSerializer<W>) -> Result<(), Error> {
visitor.writer.write_all(&self.to_le_bytes())?;
Ok(())
}
}
)*
};
}
impl_visit_le!(u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
#[derive(VisitFields)]
struct Superblock {
magic: u32, // 0x73717368
inode_count: u32,
modification_time: u32, // 0
block_size: u32,
fragment_entry_count: u32,
compression_id: u16, // 6 = zstd
block_log: u16, // log2(block_size)
flags: u16, // 0x0440
id_count: u16,
version_major: u16, // 4
version_minor: u16, // 0
root_inode_ref: u64,
bytes_used: u64,
id_table_start: u64,
xattr_id_table_start: u64,
inode_table_start: u64,
directory_table_start: u64,
fragment_table_start: u64,
export_table_start: u64,
}
impl Default for Superblock {
fn default() -> Self {
Self {
magic: 0x73717368,
inode_count: 0,
modification_time: 0,
block_size: 0,
fragment_entry_count: 0,
compression_id: 6,
block_log: 0,
flags: 0x0440,
id_count: 0,
version_major: 4,
version_minor: 0,
root_inode_ref: 0,
bytes_used: 0,
id_table_start: 0,
xattr_id_table_start: 0,
inode_table_start: 0,
directory_table_start: 0,
fragment_table_start: 0,
export_table_start: 0,
}
}
}
impl<W: AsyncWrite + Unpin + Send> VisitAsync<SquashfsSerializer<W>> for Superblock {
async fn visit_async(&self, visitor: &mut SquashfsSerializer<W>) -> Result<(), Error> {
self.visit_fields_async(visitor).try_collect().await
}
}
impl<W: Write> Visit<SquashfsSerializer<W>> for Superblock {
fn visit(&self, visitor: &mut SquashfsSerializer<W>) -> Result<(), Error> {
self.visit_fields(visitor).collect()
}
}
#[pin_project::pin_project]
pub struct MetadataBlocks<W> {
input: [u8; 8192],
input_flushed: usize,
size: usize,
size_addr: Option<u64>,
end_addr: Option<u64>,
zstd: Option<ZstdEncoder>,
output: PartialBuffer<[u8; 4096]>,
output_flushed: usize,
#[pin]
writer: W,
}
impl<W: Write + Seek> Write for MetadataBlocks<W> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let n = buf.len().min(self.input.len() - self.size);
self.input[self.size..self.size + n].copy_from_slice(&buf[..n]);
if n < buf.len() {
self.flush()?;
}
Ok(n)
}
fn flush(&mut self) -> std::io::Result<()> {
if self.size > 0 {
if self.size_addr.is_none() {
self.size_addr = Some(self.writer.stream_position()?);
self.output.unwritten_mut()[..2].copy_from_slice(&[0; 2]);
self.output.advance(2);
}
if self.output.written().len() > self.output_flushed {
let n = self
.writer
.write(&self.output.written()[self.output_flushed..])?;
self.output_flushed += n;
}
if self.output.written().len() == self.output_flushed {
self.output_flushed = 0;
self.output.reset();
}
if self.input_flushed < self.size {
if !self.output.unwritten().is_empty() {
let mut input = PartialBuffer::new(&self.input[self.input_flushed..self.size]);
self.zstd
.get_or_insert_with(|| ZstdEncoder::new(22))
.encode(&mut input, &mut self.output)?;
self.input_flushed += input.written().len();
}
} else {
if !self.output.unwritten().is_empty() {
if self.zstd.as_mut().unwrap().finish(&mut self.output)? {
self.zstd = None;
}
}
if self.zstd.is_none() && self.output.written().len() == self.output_flushed {
self.output_flushed = 0;
self.output.reset();
if let Some(addr) = self.size_addr {
let end_addr = if let Some(end_addr) = self.end_addr {
end_addr
} else {
let end_addr = self.writer.stream_position()?;
self.end_addr = Some(end_addr);
end_addr
};
self.writer.seek(std::io::SeekFrom::Start(addr))?;
self.output.unwritten_mut()[..2]
.copy_from_slice(&((end_addr - addr - 2) as u16).to_le_bytes());
self.output.advance(2);
self.size_addr = None;
}
if let Some(end_addr) = self.end_addr {
self.writer.seek(std::io::SeekFrom::Start(end_addr))?;
self.end_addr = None;
self.input_flushed = 0;
self.size = 0;
}
}
}
}
Ok(())
}
}
impl<W: AsyncWrite + AsyncSeek> AsyncWrite for MetadataBlocks<W> {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
let this = self.as_mut().project();
let n = buf.len().min(this.input.len() - *this.size);
this.input[*this.size..*this.size + n].copy_from_slice(&buf[..n]);
if n < buf.len() {
ready!(self.poll_flush(cx)?);
}
Poll::Ready(Ok(n))
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
// let this = self.as_mut();
// if self.size > 0 {
// if self.size_addr.is_none() {
// self.size_addr = Some(self.writer.stream_position()?);
// self.output.unwritten_mut()[..2].copy_from_slice(&[0; 2]);
// self.output.advance(2);
// }
// if self.output.written().len() > self.output_flushed {
// let n = self
// .writer
// .write(&self.output.written()[self.output_flushed..])?;
// self.output_flushed += n;
// }
// if self.output.written().len() == self.output_flushed {
// self.output_flushed = 0;
// self.output.reset();
// }
// if self.input_flushed < self.size {
// if !self.output.unwritten().is_empty() {
// let mut input = PartialBuffer::new(&self.input[self.input_flushed..self.size]);
// self.zstd
// .get_or_insert_with(|| ZstdEncoder::new(22))
// .encode(&mut input, &mut self.output)?;
// self.input_flushed += input.written().len();
// }
// } else {
// if !self.output.unwritten().is_empty() {
// if self.zstd.as_mut().unwrap().finish(&mut self.output)? {
// self.zstd = None;
// }
// }
// if self.zstd.is_none() && self.output.written().len() == self.output_flushed {
// self.output_flushed = 0;
// self.output.reset();
// if let Some(addr) = self.size_addr {
// let end_addr = if let Some(end_addr) = self.end_addr {
// end_addr
// } else {
// let end_addr = self.writer.stream_position()?;
// self.end_addr = Some(end_addr);
// end_addr
// };
// self.writer.seek(std::io::SeekFrom::Start(addr))?;
// self.output.unwritten_mut()[..2]
// .copy_from_slice(&((end_addr - addr - 2) as u16).to_le_bytes());
// self.output.advance(2);
// self.size_addr = None;
// }
// if let Some(end_addr) = self.end_addr {
// self.writer.seek(std::io::SeekFrom::Start(end_addr))?;
// self.end_addr = None;
// self.input_flushed = 0;
// self.size = 0;
// }
// }
// }
// }
Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::task::Poll::Ready(Ok(()))
}
}