This commit is contained in:
Aiden McClelland
2025-10-20 18:05:57 -06:00
parent 716bf920f5
commit 40b00bae75
26 changed files with 736 additions and 401 deletions

246
core/Cargo.lock generated
View File

@@ -450,20 +450,16 @@ dependencies = [
[[package]]
name = "async-compression"
version = "0.4.19"
version = "0.4.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06575e6a9673580f52661c92107baabffbf41e2141373441cbcdc47cb733003c"
checksum = "5a89bce6054c720275ac2432fbba080a66a2106a44a1b804553930ca6909f4e0"
dependencies = [
"brotli",
"flate2",
"compression-codecs",
"compression-core",
"futures-core",
"futures-io",
"memchr",
"pin-project-lite",
"tokio",
"xz2",
"zstd",
"zstd-safe",
]
[[package]]
@@ -840,24 +836,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "backhand"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81e407ed987e67ac147f68f801e84a7628107acae7ac98439ee0c39d33c599dd"
dependencies = [
"deku",
"flate2",
"lz4_flex",
"solana-nohash-hasher",
"thiserror 2.0.17",
"tracing",
"xxhash-rust",
"xz2",
"zstd",
"zstd-safe",
]
[[package]]
name = "backtrace"
version = "0.3.76"
@@ -1161,9 +1139,9 @@ checksum = "2225b558afc76c596898f5f1b3fc35cfce0eb1b13635cbd7d1b2a7177dc10ccd"
[[package]]
name = "brotli"
version = "7.0.0"
version = "8.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd"
checksum = "4bd8b9603c7aa97359dbd97ecf258968c95f3adddd6db2f7e7a5bef101c84560"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
@@ -1172,9 +1150,9 @@ dependencies = [
[[package]]
name = "brotli-decompressor"
version = "4.0.3"
version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a334ef7c9e23abf0ce748e8cd309037da93e606ad52eb372e4ce327a0dcfbdfd"
checksum = "874bb8112abecc98cbd6d81ea4fa7e94fb9449648c93cc89aa40c81c24d7de03"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
@@ -1468,6 +1446,27 @@ dependencies = [
"memchr",
]
[[package]]
name = "compression-codecs"
version = "0.4.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef8a506ec4b81c460798f572caead636d57d3d7e940f998160f52bd254bf2d23"
dependencies = [
"brotli",
"compression-core",
"flate2",
"liblzma",
"memchr",
"zstd",
"zstd-safe",
]
[[package]]
name = "compression-core"
version = "0.4.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e47641d3deaf41fb1538ac1f54735925e275eaf3bf4d55c81b137fba797e5cbb"
[[package]]
name = "concurrent-queue"
version = "1.2.4"
@@ -1941,16 +1940,6 @@ dependencies = [
"darling_macro 0.14.4",
]
[[package]]
name = "darling"
version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee"
dependencies = [
"darling_core 0.20.11",
"darling_macro 0.20.11",
]
[[package]]
name = "darling"
version = "0.21.3"
@@ -1975,20 +1964,6 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "darling_core"
version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim 0.11.1",
"syn 2.0.106",
]
[[package]]
name = "darling_core"
version = "0.21.3"
@@ -2014,17 +1989,6 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "darling_macro"
version = "0.20.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead"
dependencies = [
"darling_core 0.20.11",
"quote",
"syn 2.0.106",
]
[[package]]
name = "darling_macro"
version = "0.21.3"
@@ -2042,31 +2006,6 @@ version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
[[package]]
name = "deku"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9711031e209dc1306d66985363b4397d4c7b911597580340b93c9729b55f6eb"
dependencies = [
"bitvec 1.0.1",
"deku_derive",
"no_std_io2",
"rustversion",
]
[[package]]
name = "deku_derive"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58cb0719583cbe4e81fb40434ace2f0d22ccc3e39a74bb3796c22b451b4f139d"
dependencies = [
"darling 0.20.11",
"proc-macro-crate",
"proc-macro2",
"quote",
"syn 2.0.106",
]
[[package]]
name = "der"
version = "0.7.10"
@@ -2846,7 +2785,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9"
dependencies = [
"crc32fast",
"libz-rs-sys",
"miniz_oxide",
]
@@ -2868,6 +2806,12 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "foldhash"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
[[package]]
name = "foreign-types"
version = "0.3.2"
@@ -3300,7 +3244,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
dependencies = [
"allocator-api2",
"equivalent",
"foldhash",
"foldhash 0.1.5",
]
[[package]]
@@ -3308,6 +3252,9 @@ name = "hashbrown"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d"
dependencies = [
"allocator-api2",
]
[[package]]
name = "hashlink"
@@ -3789,6 +3736,19 @@ dependencies = [
"serde",
]
[[package]]
name = "iddqd"
version = "0.3.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bac5efd33e0c5eb0ac45cbd210541a214dac576896ca97ba08e16e3b1079cdd8"
dependencies = [
"allocator-api2",
"equivalent",
"foldhash 0.2.0",
"hashbrown 0.16.0",
"rustc-hash",
]
[[package]]
name = "ident_case"
version = "1.0.1"
@@ -4385,6 +4345,26 @@ dependencies = [
"windows-targets 0.53.5",
]
[[package]]
name = "liblzma"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73c36d08cad03a3fbe2c4e7bb3a9e84c57e4ee4135ed0b065cade3d98480c648"
dependencies = [
"liblzma-sys",
]
[[package]]
name = "liblzma-sys"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01b9596486f6d60c3bbe644c0e1be1aa6ccc472ad630fe8927b456973d7cb736"
dependencies = [
"cc",
"libc",
"pkg-config",
]
[[package]]
name = "libm"
version = "0.2.15"
@@ -4423,15 +4403,6 @@ dependencies = [
"version_check",
]
[[package]]
name = "libz-rs-sys"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "840db8cf39d9ec4dd794376f38acc40d0fc65eec2a8f484f7fd375b84602becd"
dependencies = [
"zlib-rs",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.15"
@@ -4474,23 +4445,6 @@ dependencies = [
"value-bag",
]
[[package]]
name = "lz4_flex"
version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08ab2867e3eeeca90e844d1940eab391c9dc5228783db2ed999acbc0a9ed375a"
[[package]]
name = "lzma-sys"
version = "0.1.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27"
dependencies = [
"cc",
"libc",
"pkg-config",
]
[[package]]
name = "matchers"
version = "0.2.0"
@@ -4740,15 +4694,6 @@ dependencies = [
"memoffset 0.9.1",
]
[[package]]
name = "no_std_io2"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a3564ce7035b1e4778d8cb6cacebb5d766b5e8fe5a75b9e441e33fb61a872c6"
dependencies = [
"memchr",
]
[[package]]
name = "nom"
version = "6.1.2"
@@ -7106,12 +7051,6 @@ dependencies = [
"tokio",
]
[[package]]
name = "solana-nohash-hasher"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b8a731ed60e89177c8a7ab05fe0f1511cedd3e70e773f288f9de33a9cfdc21e"
[[package]]
name = "spin"
version = "0.9.8"
@@ -7347,7 +7286,6 @@ dependencies = [
"async-stream",
"async-trait",
"axum 0.8.6",
"backhand",
"backtrace-on-stack-overflow",
"barrage",
"base32 0.5.1",
@@ -7388,6 +7326,7 @@ dependencies = [
"hyper",
"hyper-util",
"id-pool",
"iddqd",
"imbl",
"imbl-value 0.4.3 (registry+https://github.com/rust-lang/crates.io-index)",
"include_dir",
@@ -7487,6 +7426,7 @@ dependencies = [
"url",
"urlencoding",
"uuid",
"visit-rs",
"x25519-dalek",
"zbus",
"zeroize",
@@ -9592,6 +9532,29 @@ dependencies = [
"syn 2.0.106",
]
[[package]]
name = "visit-rs"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55cb1924bd417090100e20559c09bea160cd6189f3d8867c832e9985d43756fa"
dependencies = [
"async-stream",
"futures",
"serde",
"visit-rs-derive",
]
[[package]]
name = "visit-rs-derive"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2071973e6712c1caa51d425c2caa56fc0694028ba686311d2264e129aa19d14d"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.106",
]
[[package]]
name = "void"
version = "1.0.2"
@@ -10396,15 +10359,6 @@ version = "0.8.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3"
[[package]]
name = "xz2"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "388c44dc09d76f1536602ead6d325eb532f5c122f17782bd57fb47baeeb767e2"
dependencies = [
"lzma-sys",
]
[[package]]
name = "yajrc"
version = "0.1.3"
@@ -10618,12 +10572,6 @@ dependencies = [
"syn 2.0.106",
]
[[package]]
name = "zlib-rs"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f06ae92f42f5e5c42443fd094f245eb656abf56dd7cce9b8b263236565e00f2"
[[package]]
name = "zstd"
version = "0.13.3"

View File

@@ -85,16 +85,16 @@ async-acme = { version = "0.6.0", git = "https://github.com/dr-bonez/async-acme.
"use_rustls",
"use_tokio",
] }
async-compression = { version = "0.4.4", features = [
async-compression = { version = "0.4.32", features = [
"gzip",
"brotli",
"zstd",
"tokio",
] }
async-stream = "0.3.5"
async-trait = "0.1.74"
axum = { version = "0.8.4", features = ["ws"] }
barrage = "0.2.3"
backhand = "0.21.0"
backtrace-on-stack-overflow = { version = "0.3.0", optional = true }
base32 = "0.5.0"
base64 = "0.22.1"
@@ -153,6 +153,7 @@ id-pool = { version = "0.2.2", default-features = false, features = [
"serde",
"u16",
] }
iddqd = "0.3.14"
imbl = { version = "6", features = ["serde", "small-chunks"] }
imbl-value = { version = "0.4.3", features = ["ts-rs"] }
include_dir = { version = "0.7.3", features = ["metadata"] }
@@ -283,6 +284,7 @@ unix-named-pipe = "0.2.0"
url = { version = "2.4.1", features = ["serde"] }
urlencoding = "2.1.3"
uuid = { version = "1.4.1", features = ["v4"] }
visit-rs = "0.1.1"
x25519-dalek = { version = "2.0.1", features = ["static_secrets"] }
zbus = "5.1.1"
zeroize = "1.6.0"

View File

@@ -250,30 +250,6 @@ impl NetworkInterfaceInfo {
}
(&*LO, &*LOOPBACK)
}
pub fn lxc_bridge() -> (&'static GatewayId, &'static Self) {
lazy_static! {
static ref LXCBR0: GatewayId =
GatewayId::from(InternedString::intern(START9_BRIDGE_IFACE));
static ref LXC_BRIDGE: NetworkInterfaceInfo = NetworkInterfaceInfo {
name: Some(InternedString::from_static("LXC Bridge Interface")),
public: Some(false),
secure: Some(true),
ip_info: Some(IpInfo {
name: START9_BRIDGE_IFACE.into(),
scope_id: 0,
device_type: None,
subnets: [IpNet::new(HOST_IP.into(), 24).unwrap()]
.into_iter()
.collect(),
lan_ip: [IpAddr::from(HOST_IP)].into_iter().collect(),
wan_ip: None,
ntp_servers: Default::default(),
dns_servers: Default::default(),
}),
};
}
(&*LXCBR0, &*LXC_BRIDGE)
}
pub fn public(&self) -> bool {
self.public.unwrap_or_else(|| {
!self.ip_info.as_ref().map_or(true, |ip_info| {
@@ -339,7 +315,9 @@ pub struct IpInfo {
pub enum NetworkInterfaceType {
Ethernet,
Wireless,
Bridge,
Wireguard,
Loopback,
}
#[derive(Debug, Deserialize, Serialize, HasModel, TS)]

View File

@@ -415,10 +415,7 @@ impl Resolver {
{
if let Some(res) = self.net_iface.peek(|i| {
i.values()
.chain([
NetworkInterfaceInfo::loopback().1,
NetworkInterfaceInfo::lxc_bridge().1,
])
.chain([NetworkInterfaceInfo::loopback().1])
.filter_map(|i| i.ip_info.as_ref())
.find(|i| i.subnets.iter().any(|s| s.contains(&src)))
.map(|ip_info| {

View File

@@ -5,6 +5,7 @@ use std::sync::{Arc, Weak};
use futures::channel::oneshot;
use helpers::NonDetachingJoinHandle;
use id_pool::IdPool;
use iddqd::{IdOrdItem, IdOrdMap};
use imbl::OrdMap;
use models::GatewayId;
use rpc_toolkit::{from_fn_async, Context, HandlerArgs, HandlerExt, ParentHandler};
@@ -14,7 +15,7 @@ use tokio::sync::mpsc;
use crate::context::{CliContext, RpcContext};
use crate::db::model::public::NetworkInterfaceInfo;
use crate::net::gateway::{DynInterfaceFilter, InterfaceFilter, SecureFilter};
use crate::net::gateway::{DynInterfaceFilter, InterfaceFilter};
use crate::net::utils::ipv6_is_link_local;
use crate::prelude::*;
use crate::util::serde::{display_serializable, HandlerExtSerde};
@@ -60,17 +61,10 @@ pub fn forward_api<C: Context>() -> ParentHandler<C> {
}
let mut table = Table::new();
table.add_row(row![bc => "FROM", "TO", "FILTER / GATEWAY"]);
table.add_row(row![bc => "FROM", "TO", "FILTER"]);
for (external, target) in res.0 {
table.add_row(row![external, target.target, target.filter]);
for (source, gateway) in target.gateways {
table.add_row(row![
format!("{}:{}", source, external),
target.target,
gateway
]);
}
}
table.print_tty(false)?;
@@ -85,41 +79,43 @@ struct ForwardRequest {
external: u16,
target: SocketAddr,
filter: DynInterfaceFilter,
rc: Weak<()>,
rc: Arc<()>,
}
#[derive(Clone)]
struct ForwardEntry {
external: u16,
target: SocketAddr,
prev_filter: DynInterfaceFilter,
forwards: BTreeMap<SocketAddr, GatewayId>,
rc: Weak<()>,
filter: BTreeMap<DynInterfaceFilter, (SocketAddr, Weak<()>)>,
forwards: BTreeMap<SocketAddr, (GatewayId, SocketAddr)>,
}
impl IdOrdItem for ForwardEntry {
type Key<'a> = u16;
fn key(&self) -> Self::Key<'_> {
self.external
}
iddqd::id_upcast!();
}
impl ForwardEntry {
fn new(external: u16, target: SocketAddr, rc: Weak<()>) -> Self {
fn new(external: u16) -> Self {
Self {
external,
target,
prev_filter: false.into_dyn(),
filter: BTreeMap::new(),
forwards: BTreeMap::new(),
rc,
}
}
fn take(&mut self) -> Self {
Self {
external: self.external,
target: self.target,
prev_filter: std::mem::replace(&mut self.prev_filter, false.into_dyn()),
filter: std::mem::take(&mut self.filter),
forwards: std::mem::take(&mut self.forwards),
rc: self.rc.clone(),
}
}
async fn destroy(mut self) -> Result<(), Error> {
while let Some((source, interface)) = self.forwards.pop_first() {
unforward(interface.as_str(), source, self.target).await?;
while let Some((source, (interface, target))) = self.forwards.pop_first() {
unforward(interface.as_str(), source, target).await?;
}
Ok(())
}
@@ -127,38 +123,37 @@ impl ForwardEntry {
async fn update(
&mut self,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
filter: Option<DynInterfaceFilter>,
) -> Result<(), Error> {
if self.rc.strong_count() == 0 {
return self.take().destroy().await;
}
let filter_ref = filter.as_ref().unwrap_or(&self.prev_filter);
let mut keep = BTreeSet::<SocketAddr>::new();
for (iface, info) in ip_info
.iter()
// .chain([NetworkInterfaceInfo::loopback()])
.filter(|(id, info)| filter_ref.filter(*id, *info))
{
if let Some(ip_info) = &info.ip_info {
for ipnet in &ip_info.subnets {
let addr = match ipnet.addr() {
IpAddr::V6(ip6) => SocketAddrV6::new(
ip6,
self.external,
0,
if ipv6_is_link_local(ip6) {
ip_info.scope_id
} else {
0
},
)
.into(),
ip => SocketAddr::new(ip, self.external),
};
keep.insert(addr);
if !self.forwards.contains_key(&addr) {
forward(iface.as_str(), addr, self.target).await?;
self.forwards.insert(addr, iface.clone());
for (iface, info) in ip_info.iter() {
if let Some(target) = self
.filter
.iter()
.filter(|(_, (_, rc))| rc.strong_count() > 0)
.find(|(filter, _)| filter.filter(iface, info))
.map(|(_, (target, _))| *target)
{
if let Some(ip_info) = &info.ip_info {
for ipnet in &ip_info.subnets {
let addr = match ipnet.addr() {
IpAddr::V6(ip6) => SocketAddrV6::new(
ip6,
self.external,
0,
if ipv6_is_link_local(ip6) {
ip_info.scope_id
} else {
0
},
)
.into(),
ip => SocketAddr::new(ip, self.external),
};
keep.insert(addr);
if !self.forwards.contains_key(&addr) {
forward(iface.as_str(), addr, target).await?;
self.forwards.insert(addr, (iface.clone(), target));
}
}
}
}
@@ -170,13 +165,10 @@ impl ForwardEntry {
.filter(|a| !keep.contains(a))
.collect::<Vec<_>>();
for rm in rm {
if let Some((source, interface)) = self.forwards.remove_entry(&rm) {
unforward(interface.as_str(), source, self.target).await?;
if let Some((source, (interface, target))) = self.forwards.remove_entry(&rm) {
unforward(interface.as_str(), source, target).await?;
}
}
if let Some(filter) = filter {
self.prev_filter = filter;
}
Ok(())
}
@@ -186,20 +178,34 @@ impl ForwardEntry {
external,
target,
filter,
rc,
mut rc,
}: ForwardRequest,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
) -> Result<(), Error> {
if external != self.external || target != self.target {
self.take().destroy().await?;
*self = Self::new(external, target, rc);
self.update(ip_info, Some(filter)).await?;
} else {
self.rc = rc;
self.update(ip_info, Some(filter).filter(|f| f != &self.prev_filter))
.await?;
) -> Result<Arc<()>, Error> {
if external != self.external {
return Err(Error::new(
eyre!("Mismatched external port in ForwardEntry"),
ErrorKind::InvalidRequest,
));
}
Ok(())
let entry = self
.filter
.entry(filter)
.or_insert_with(|| (target, Arc::downgrade(&rc)));
if entry.0 != target {
entry.0 = target;
entry.1 = Arc::downgrade(&rc);
}
if let Some(existing) = entry.1.upgrade() {
rc = existing;
} else {
entry.1 = Arc::downgrade(&rc);
}
self.update(ip_info).await?;
Ok(rc)
}
}
impl Drop for ForwardEntry {
@@ -215,17 +221,17 @@ impl Drop for ForwardEntry {
#[derive(Default, Clone)]
struct ForwardState {
state: BTreeMap<u16, ForwardEntry>,
state: IdOrdMap<ForwardEntry>,
}
impl ForwardState {
async fn handle_request(
&mut self,
request: ForwardRequest,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
) -> Result<(), Error> {
) -> Result<Arc<()>, Error> {
self.state
.entry(request.external)
.or_insert_with(|| ForwardEntry::new(request.external, request.target, Weak::new()))
.or_insert_with(|| ForwardEntry::new(request.external))
.update_request(request, ip_info)
.await
}
@@ -233,10 +239,9 @@ impl ForwardState {
&mut self,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
) -> Result<(), Error> {
for entry in self.state.values_mut() {
entry.update(ip_info, None).await?;
for mut entry in self.state.iter_mut() {
entry.update(ip_info).await?;
}
self.state.retain(|_, fwd| fwd.rc.strong_count() > 0);
Ok(())
}
}
@@ -254,7 +259,6 @@ pub struct ForwardTable(pub BTreeMap<u16, ForwardTarget>);
pub struct ForwardTarget {
pub target: SocketAddr,
pub filter: String,
pub gateways: BTreeMap<SocketAddr, GatewayId>,
}
impl From<&ForwardState> for ForwardTable {
fn from(value: &ForwardState) -> Self {
@@ -262,15 +266,20 @@ impl From<&ForwardState> for ForwardTable {
value
.state
.iter()
.map(|(external, entry)| {
(
*external,
ForwardTarget {
target: entry.target,
filter: format!("{:?}", entry.prev_filter),
gateways: entry.forwards.clone(),
},
)
.flat_map(|entry| {
entry
.filter
.iter()
.filter(|(_, (_, rc))| rc.strong_count() > 0)
.map(|(filter, (target, _))| {
(
entry.external,
ForwardTarget {
target: *target,
filter: format!("{:?}", filter),
},
)
})
})
.collect(),
)
@@ -278,13 +287,15 @@ impl From<&ForwardState> for ForwardTable {
}
enum ForwardCommand {
Forward(ForwardRequest, oneshot::Sender<Result<(), Error>>),
Forward(ForwardRequest, oneshot::Sender<Result<Arc<()>, Error>>),
Sync(oneshot::Sender<Result<(), Error>>),
DumpTable(oneshot::Sender<ForwardTable>),
}
#[test]
fn test() {
use crate::net::gateway::SecureFilter;
assert_ne!(
false.into_dyn(),
SecureFilter { secure: false }.into_dyn().into_dyn()
@@ -340,13 +351,13 @@ impl PortForwardController {
external,
target,
filter,
rc: Arc::downgrade(&rc),
rc,
},
send,
))
.map_err(err_has_exited)?;
recv.await.map_err(err_has_exited)?.map(|_| rc)
recv.await.map_err(err_has_exited)?
}
pub async fn gc(&self) -> Result<(), Error> {
let (send, recv) = oneshot::channel();

View File

@@ -585,18 +585,19 @@ async fn watch_ip(
loop {
until
.run(async {
let external = active_connection_proxy.state_flags().await? & 0x80 != 0;
if external {
return Ok(());
}
let device_type = match device_proxy.device_type().await? {
1 => Some(NetworkInterfaceType::Ethernet),
2 => Some(NetworkInterfaceType::Wireless),
13 => Some(NetworkInterfaceType::Bridge),
29 => Some(NetworkInterfaceType::Wireguard),
32 => Some(NetworkInterfaceType::Loopback),
_ => None,
};
if device_type == Some(NetworkInterfaceType::Loopback) {
return Ok(());
}
let name = InternedString::from(active_connection_proxy.id().await?);
let dhcp4_config = active_connection_proxy.dhcp4_config().await?;
@@ -787,13 +788,7 @@ impl NetworkInterfaceWatcher {
watch_activated: impl IntoIterator<Item = GatewayId>,
) -> Self {
let ip_info = Watch::new(OrdMap::new());
let activated = Watch::new(
watch_activated
.into_iter()
.chain([NetworkInterfaceInfo::lxc_bridge().0.clone()])
.map(|k| (k, false))
.collect(),
);
let activated = Watch::new(watch_activated.into_iter().map(|k| (k, false)).collect());
Self {
activated: activated.clone(),
ip_info: ip_info.clone(),
@@ -1384,14 +1379,12 @@ impl ListenerMap {
fn update(
&mut self,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
lxc_bridge: bool,
filter: &impl InterfaceFilter,
) -> Result<(), Error> {
let mut keep = BTreeSet::<SocketAddr>::new();
for (_, info) in ip_info
.iter()
.chain([NetworkInterfaceInfo::loopback()])
.chain(Some(NetworkInterfaceInfo::lxc_bridge()).filter(|_| lxc_bridge))
.filter(|(id, info)| filter.filter(*id, *info))
{
if let Some(ip_info) = &info.ip_info {
@@ -1466,10 +1459,7 @@ pub fn lookup_info_by_addr(
) -> Option<(&GatewayId, &NetworkInterfaceInfo)> {
ip_info
.iter()
.chain([
NetworkInterfaceInfo::loopback(),
NetworkInterfaceInfo::lxc_bridge(),
])
.chain([NetworkInterfaceInfo::loopback()])
.find(|(_, i)| {
i.ip_info
.as_ref()
@@ -1495,16 +1485,10 @@ impl NetworkInterfaceListener {
filter: &impl InterfaceFilter,
) -> Poll<Result<Accepted, Error>> {
while self.ip_info.poll_changed(cx).is_ready()
|| self.activated.poll_changed(cx).is_ready()
|| !DynInterfaceFilterT::eq(&self.listeners.prev_filter, filter.as_any())
{
let lxc_bridge = self.activated.peek(|a| {
a.get(NetworkInterfaceInfo::lxc_bridge().0)
.copied()
.unwrap_or_default()
});
self.ip_info
.peek_and_mark_seen(|ip_info| self.listeners.update(ip_info, lxc_bridge, filter))?;
.peek_and_mark_seen(|ip_info| self.listeners.update(ip_info, filter))?;
}
self.listeners.poll_accept(cx)
}
@@ -1562,9 +1546,12 @@ impl SelfContainedNetworkInterfaceListener {
pub fn bind(port: u16) -> Self {
let ip_info = Watch::new(OrdMap::new());
let activated = Watch::new(
[(NetworkInterfaceInfo::lxc_bridge().0.clone(), false)]
.into_iter()
.collect(),
[(
GatewayId::from(InternedString::from(START9_BRIDGE_IFACE)),
false,
)]
.into_iter()
.collect(),
);
let _watch_thread = tokio::spawn(watcher(ip_info.clone(), activated.clone())).into();
Self {

View File

@@ -6,7 +6,7 @@ use color_eyre::eyre::eyre;
use imbl::{vector, OrdMap};
use imbl_value::InternedString;
use ipnet::IpNet;
use models::{HostId, OptionExt, PackageId};
use models::{GatewayId, HostId, OptionExt, PackageId};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tracing::instrument;
@@ -16,7 +16,7 @@ use crate::db::model::Database;
use crate::error::ErrorCollection;
use crate::hostname::Hostname;
use crate::net::dns::DnsController;
use crate::net::forward::PortForwardController;
use crate::net::forward::{PortForwardController, START9_BRIDGE_IFACE};
use crate::net::gateway::{
AndFilter, DynInterfaceFilter, IdFilter, InterfaceFilter, NetworkInterfaceController, OrFilter,
PublicFilter, SecureFilter,
@@ -283,9 +283,9 @@ impl NetServiceData {
IdFilter(
NetworkInterfaceInfo::loopback().0.clone(),
),
IdFilter(
NetworkInterfaceInfo::lxc_bridge().0.clone(),
),
IdFilter(GatewayId::from(InternedString::from(
START9_BRIDGE_IFACE,
))),
)
.into_dyn(),
acme: None,

View File

@@ -3,12 +3,10 @@ use tokio::io::{AsyncSeek, AsyncWrite};
use crate::prelude::*;
use crate::util::io::TrackingIO;
#[async_trait::async_trait]
pub trait Sink: AsyncWrite + Unpin + Send {
async fn current_position(&mut self) -> Result<u64, Error>;
fn current_position(&mut self) -> impl Future<Output = Result<u64, Error>> + Send + '_;
}
#[async_trait::async_trait]
impl<S: AsyncWrite + AsyncSeek + Unpin + Send> Sink for S {
async fn current_position(&mut self) -> Result<u64, Error> {
use tokio::io::AsyncSeekExt;
@@ -17,7 +15,6 @@ impl<S: AsyncWrite + AsyncSeek + Unpin + Send> Sink for S {
}
}
#[async_trait::async_trait]
impl<W: AsyncWrite + Unpin + Send> Sink for TrackingIO<W> {
async fn current_position(&mut self) -> Result<u64, Error> {
Ok(self.position())

View File

@@ -1,14 +1,17 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use clap::Parser;
use imbl_value::InternedString;
use ipnet::Ipv4Net;
use models::GatewayId;
use rpc_toolkit::{from_fn_async, Context, Empty, HandlerArgs, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use crate::context::CliContext;
use crate::net::gateway::{IdFilter, InterfaceFilter};
use crate::prelude::*;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::db::GatewayPort;
use crate::tunnel::wg::{ClientConfig, WgConfig, WgSubnetClients, WgSubnetConfig};
use crate::util::serde::{display_serializable, HandlerExtSerde};
@@ -27,26 +30,26 @@ pub fn tunnel_api<C: Context>() -> ParentHandler<C> {
"subnet",
subnet_api::<C>().with_about("Add, remove, or modify subnets"),
)
// .subcommand(
// "port-forward",
// ParentHandler::<C>::new()
// .subcommand(
// "add",
// from_fn_async(add_forward)
// .with_metadata("sync_db", Value::Bool(true))
// .no_display()
// .with_about("Add a new port forward")
// .with_call_remote::<CliContext>(),
// )
// .subcommand(
// "remove",
// from_fn_async(remove_forward)
// .with_metadata("sync_db", Value::Bool(true))
// .no_display()
// .with_about("Remove a port forward")
// .with_call_remote::<CliContext>(),
// ),
// )
.subcommand(
"port-forward",
ParentHandler::<C>::new()
.subcommand(
"add",
from_fn_async(add_forward)
.with_metadata("sync_db", Value::Bool(true))
.no_display()
.with_about("Add a new port forward")
.with_call_remote::<CliContext>(),
)
.subcommand(
"remove",
from_fn_async(remove_forward)
.with_metadata("sync_db", Value::Bool(true))
.no_display()
.with_about("Remove a port forward")
.with_call_remote::<CliContext>(),
),
)
}
#[derive(Deserialize, Serialize, Parser)]
@@ -345,3 +348,53 @@ pub async fn show_config(
(wan_addr, wg.as_port().de()?).into(),
))
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct AddPortForwardParams {
source: GatewayPort,
target: SocketAddrV4,
}
pub async fn add_forward(
ctx: TunnelContext,
AddPortForwardParams { source, target }: AddPortForwardParams,
) -> Result<(), Error> {
ctx.db
.mutate(|db| db.as_port_forwards_mut().insert(&source, &target))
.await
.result?;
let rc = ctx
.forward
.add(
source.1,
IdFilter(source.0.clone()).into_dyn(),
target.into(),
)
.await?;
ctx.active_forwards.mutate(|m| {
m.insert(source, rc);
});
Ok(())
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct RemovePortForwardParams {
source: GatewayPort,
}
pub async fn remove_forward(
ctx: TunnelContext,
RemovePortForwardParams { source, .. }: RemovePortForwardParams,
) -> Result<(), Error> {
ctx.db
.mutate(|db| db.as_port_forwards_mut().remove(&source))
.await
.result?;
if let Some(rc) = ctx.active_forwards.mutate(|m| m.remove(&source)) {
drop(rc);
ctx.forward.gc().await?;
}
Ok(())
}

View File

@@ -1,4 +1,4 @@
use std::collections::BTreeSet;
use std::collections::{BTreeMap, BTreeSet};
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::ops::Deref;
use std::path::{Path, PathBuf};
@@ -10,6 +10,7 @@ use helpers::NonDetachingJoinHandle;
use http::HeaderMap;
use imbl::OrdMap;
use imbl_value::InternedString;
use models::GatewayId;
use patch_db::PatchDb;
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{CallRemote, Context, Empty};
@@ -22,12 +23,13 @@ use url::Url;
use crate::auth::Sessions;
use crate::context::config::ContextConfig;
use crate::context::{CliContext, RpcContext};
use crate::db::model::public::NetworkInterfaceType;
use crate::middleware::auth::AuthContext;
use crate::net::forward::PortForwardController;
use crate::net::gateway::NetworkInterfaceWatcher;
use crate::net::gateway::{IdFilter, InterfaceFilter, NetworkInterfaceWatcher};
use crate::prelude::*;
use crate::rpc_continuations::{OpenAuthedContinuations, RpcContinuations};
use crate::tunnel::db::TunnelDatabase;
use crate::tunnel::db::{GatewayPort, TunnelDatabase};
use crate::tunnel::TUNNEL_DEFAULT_PORT;
use crate::util::io::read_file_to_string;
use crate::util::sync::SyncMutex;
@@ -73,6 +75,7 @@ pub struct TunnelContextSeed {
pub ephemeral_sessions: SyncMutex<Sessions>,
pub net_iface: NetworkInterfaceWatcher,
pub forward: PortForwardController,
pub active_forwards: SyncMutex<BTreeMap<GatewayPort, Arc<()>>>,
pub masquerade_thread: NonDetachingJoinHandle<()>,
pub shutdown: Sender<()>,
}
@@ -114,7 +117,17 @@ impl TunnelContext {
let mut masquerade_net_iface = net_iface.subscribe();
let masquerade_thread = tokio::spawn(async move {
loop {
for iface in masquerade_net_iface.peek(|i| i.keys().cloned().collect::<Vec<_>>()) {
for iface in masquerade_net_iface.peek(|i| {
i.iter()
.filter(|(_, info)| {
dbg!(info).ip_info.as_ref().map_or(false, |i| {
dbg!(i).device_type != Some(NetworkInterfaceType::Wireguard)
})
})
.map(|(name, _)| name)
.cloned()
.collect::<Vec<_>>()
}) {
if Command::new("iptables")
.arg("-t")
.arg("nat")
@@ -128,6 +141,7 @@ impl TunnelContext {
.await
.is_err()
{
tracing::info!("Adding masquerade rule for interface {}", iface);
Command::new("iptables")
.arg("-t")
.arg("nat")
@@ -144,11 +158,23 @@ impl TunnelContext {
}
masquerade_net_iface.changed().await;
tracing::info!("Network interfaces changed, updating masquerade rules");
}
})
.into();
db.peek().await.into_wg().de()?.sync().await?;
let peek = db.peek().await;
peek.as_wg().de()?.sync().await?;
let mut active_forwards = BTreeMap::new();
for (from, to) in peek.as_port_forwards().de()?.0 {
active_forwards.insert(
from.clone(),
forward
.add(from.1, IdFilter(from.0).into_dyn(), to.into())
.await?,
);
}
Ok(Self(Arc::new(TunnelContextSeed {
listen,
@@ -164,6 +190,7 @@ impl TunnelContext {
ephemeral_sessions: SyncMutex::new(Sessions::new()),
net_iface,
forward,
active_forwards: SyncMutex::new(active_forwards),
masquerade_thread,
shutdown,
})))

View File

@@ -2,9 +2,12 @@ use std::collections::BTreeMap;
use std::net::SocketAddrV4;
use std::path::PathBuf;
use clap::builder::ValueParserFactory;
use clap::Parser;
use imbl::HashMap;
use imbl_value::InternedString;
use itertools::Itertools;
use models::{FromStrParser, GatewayId};
use patch_db::json_ptr::{JsonPointer, ROOT};
use patch_db::Dump;
use rpc_toolkit::yajrc::RpcError;
@@ -20,7 +23,52 @@ use crate::sign::AnyVerifyingKey;
use crate::tunnel::auth::SignerInfo;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::wg::WgServer;
use crate::util::serde::{apply_expr, HandlerExtSerde};
use crate::util::serde::{apply_expr, deserialize_from_str, serialize_display, HandlerExtSerde};
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct GatewayPort(pub GatewayId, pub u16);
impl std::fmt::Display for GatewayPort {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}:{}", self.0, self.1)
}
}
impl std::str::FromStr for GatewayPort {
type Err = crate::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut parts = s.splitn(2, ':');
let gw: GatewayId = parts
.next()
.ok_or_else(|| Error::new(eyre!("missing gateway id"), ErrorKind::ParseNetAddress))?
.parse()?;
let port: u16 = parts
.next()
.ok_or_else(|| Error::new(eyre!("missing port"), ErrorKind::ParseNetAddress))?
.parse()?;
Ok(GatewayPort(gw, port))
}
}
impl Serialize for GatewayPort {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serialize_display(self, serializer)
}
}
impl<'de> Deserialize<'de> for GatewayPort {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserialize_from_str(deserializer)
}
}
impl ValueParserFactory for GatewayPort {
type Parser = FromStrParser<Self>;
fn value_parser() -> Self::Parser {
FromStrParser::new()
}
}
#[derive(Default, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "camelCase")]
@@ -30,7 +78,20 @@ pub struct TunnelDatabase {
pub password: String,
pub auth_pubkeys: HashMap<AnyVerifyingKey, SignerInfo>,
pub wg: WgServer,
pub port_forwards: BTreeMap<SocketAddrV4, SocketAddrV4>,
pub port_forwards: PortForwards,
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct PortForwards(pub BTreeMap<GatewayPort, SocketAddrV4>);
impl Map for PortForwards {
type Key = GatewayPort;
type Value = SocketAddrV4;
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
Self::key_string(key)
}
fn key_string(key: &Self::Key) -> Result<InternedString, Error> {
Ok(InternedString::from_display(key))
}
}
pub fn db_api<C: Context>() -> ParentHandler<C> {

View File

@@ -1 +0,0 @@
use crate::prelude::*;

View File

@@ -1,13 +1,11 @@
use axum::Router;
use futures::future::ready;
use rpc_toolkit::{from_fn_async, Context, HandlerExt, ParentHandler, Server};
use rpc_toolkit::Server;
use crate::context::CliContext;
use crate::middleware::auth::Auth;
use crate::middleware::cors::Cors;
use crate::net::static_server::{bad_request, not_found, server_error};
use crate::net::web_server::{Accept, WebServer};
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::tunnel::context::TunnelContext;
@@ -15,7 +13,6 @@ pub mod api;
pub mod auth;
pub mod context;
pub mod db;
pub mod forward;
pub mod wg;
pub const TUNNEL_DEFAULT_PORT: u16 = 5960;

View File

@@ -14,7 +14,7 @@ use std::time::Duration;
use bytes::{Buf, BytesMut};
use clap::builder::ValueParserFactory;
use futures::future::{BoxFuture, Fuse};
use futures::{AsyncSeek, FutureExt, Stream, StreamExt, TryStreamExt};
use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
use helpers::{AtomicFile, NonDetachingJoinHandle};
use inotify::{EventMask, EventStream, Inotify, WatchMask};
use models::FromStrParser;
@@ -22,7 +22,8 @@ use nix::unistd::{Gid, Uid};
use serde::{Deserialize, Serialize};
use tokio::fs::{File, OpenOptions};
use tokio::io::{
duplex, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf, WriteHalf,
duplex, AsyncRead, AsyncReadExt, AsyncSeek, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf,
SeekFrom, WriteHalf,
};
use tokio::net::TcpStream;
use tokio::sync::{Notify, OwnedMutexGuard};

View File

@@ -49,6 +49,7 @@ pub mod net;
pub mod rpc;
pub mod rpc_client;
pub mod serde;
pub mod squashfs;
pub mod sync;
#[derive(Clone, Copy, Debug, ::serde::Deserialize, ::serde::Serialize)]

View File

@@ -0,0 +1,268 @@
use std::io::{Seek, Write};
use std::path::Path;
use std::task::Poll;
use async_compression::codecs::{Encode, ZstdEncoder};
use async_compression::core::util::PartialBuffer;
use futures::{ready, TryStreamExt};
use tokio::io::{AsyncSeek, AsyncWrite};
use visit_rs::{Visit, VisitAsync, VisitFields, VisitFieldsAsync, Visitor};
use crate::prelude::*;
use crate::registry::os::asset::add;
struct SquashfsSerializer<W> {
writer: W,
}
impl<W> Visitor for SquashfsSerializer<W> {
type Result = Result<(), Error>;
}
macro_rules! impl_visit_le {
($($ty:ty),*) => {
$(
impl<W: AsyncWrite + Unpin + Send> VisitAsync<SquashfsSerializer<W>> for $ty {
async fn visit_async(&self, visitor: &mut SquashfsSerializer<W>) -> Result<(), Error> {
use tokio::io::AsyncWriteExt;
visitor.writer.write_all(&self.to_le_bytes()).await?;
Ok(())
}
}
impl<W: Write> Visit<SquashfsSerializer<W>> for $ty {
fn visit(&self, visitor: &mut SquashfsSerializer<W>) -> Result<(), Error> {
visitor.writer.write_all(&self.to_le_bytes())?;
Ok(())
}
}
)*
};
}
impl_visit_le!(u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
#[derive(VisitFields)]
struct Superblock {
magic: u32, // 0x73717368
inode_count: u32,
modification_time: u32, // 0
block_size: u32,
fragment_entry_count: u32,
compression_id: u16, // 6 = zstd
block_log: u16, // log2(block_size)
flags: u16, // 0x0440
id_count: u16,
version_major: u16, // 4
version_minor: u16, // 0
root_inode_ref: u64,
bytes_used: u64,
id_table_start: u64,
xattr_id_table_start: u64,
inode_table_start: u64,
directory_table_start: u64,
fragment_table_start: u64,
export_table_start: u64,
}
impl Default for Superblock {
fn default() -> Self {
Self {
magic: 0x73717368,
inode_count: 0,
modification_time: 0,
block_size: 0,
fragment_entry_count: 0,
compression_id: 6,
block_log: 0,
flags: 0x0440,
id_count: 0,
version_major: 4,
version_minor: 0,
root_inode_ref: 0,
bytes_used: 0,
id_table_start: 0,
xattr_id_table_start: 0,
inode_table_start: 0,
directory_table_start: 0,
fragment_table_start: 0,
export_table_start: 0,
}
}
}
impl<W: AsyncWrite + Unpin + Send> VisitAsync<SquashfsSerializer<W>> for Superblock {
async fn visit_async(&self, visitor: &mut SquashfsSerializer<W>) -> Result<(), Error> {
self.visit_fields_async(visitor).try_collect().await
}
}
impl<W: Write> Visit<SquashfsSerializer<W>> for Superblock {
fn visit(&self, visitor: &mut SquashfsSerializer<W>) -> Result<(), Error> {
self.visit_fields(visitor).collect()
}
}
#[pin_project::pin_project]
pub struct MetadataBlocks<W> {
input: [u8; 8192],
input_flushed: usize,
size: usize,
size_addr: Option<u64>,
end_addr: Option<u64>,
zstd: Option<ZstdEncoder>,
output: PartialBuffer<[u8; 4096]>,
output_flushed: usize,
#[pin]
writer: W,
}
impl<W: Write + Seek> Write for MetadataBlocks<W> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let n = buf.len().min(self.input.len() - self.size);
self.input[self.size..self.size + n].copy_from_slice(&buf[..n]);
if n < buf.len() {
self.flush()?;
}
Ok(n)
}
fn flush(&mut self) -> std::io::Result<()> {
if self.size > 0 {
if self.size_addr.is_none() {
self.size_addr = Some(self.writer.stream_position()?);
self.output.unwritten_mut()[..2].copy_from_slice(&[0; 2]);
self.output.advance(2);
}
if self.output.written().len() > self.output_flushed {
let n = self
.writer
.write(&self.output.written()[self.output_flushed..])?;
self.output_flushed += n;
}
if self.output.written().len() == self.output_flushed {
self.output_flushed = 0;
self.output.reset();
}
if self.input_flushed < self.size {
if !self.output.unwritten().is_empty() {
let mut input = PartialBuffer::new(&self.input[self.input_flushed..self.size]);
self.zstd
.get_or_insert_with(|| ZstdEncoder::new(22))
.encode(&mut input, &mut self.output)?;
self.input_flushed += input.written().len();
}
} else {
if !self.output.unwritten().is_empty() {
if self.zstd.as_mut().unwrap().finish(&mut self.output)? {
self.zstd = None;
}
}
if self.zstd.is_none() && self.output.written().len() == self.output_flushed {
self.output_flushed = 0;
self.output.reset();
if let Some(addr) = self.size_addr {
let end_addr = if let Some(end_addr) = self.end_addr {
end_addr
} else {
let end_addr = self.writer.stream_position()?;
self.end_addr = Some(end_addr);
end_addr
};
self.writer.seek(std::io::SeekFrom::Start(addr))?;
self.output.unwritten_mut()[..2]
.copy_from_slice(&((end_addr - addr - 2) as u16).to_le_bytes());
self.output.advance(2);
self.size_addr = None;
}
if let Some(end_addr) = self.end_addr {
self.writer.seek(std::io::SeekFrom::Start(end_addr))?;
self.end_addr = None;
self.input_flushed = 0;
self.size = 0;
}
}
}
}
Ok(())
}
}
impl<W: AsyncWrite + AsyncSeek> AsyncWrite for MetadataBlocks<W> {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
let this = self.as_mut().project();
let n = buf.len().min(this.input.len() - *this.size);
this.input[*this.size..*this.size + n].copy_from_slice(&buf[..n]);
if n < buf.len() {
ready!(self.poll_flush(cx)?);
}
Poll::Ready(Ok(n))
}
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
// let this = self.as_mut();
// if self.size > 0 {
// if self.size_addr.is_none() {
// self.size_addr = Some(self.writer.stream_position()?);
// self.output.unwritten_mut()[..2].copy_from_slice(&[0; 2]);
// self.output.advance(2);
// }
// if self.output.written().len() > self.output_flushed {
// let n = self
// .writer
// .write(&self.output.written()[self.output_flushed..])?;
// self.output_flushed += n;
// }
// if self.output.written().len() == self.output_flushed {
// self.output_flushed = 0;
// self.output.reset();
// }
// if self.input_flushed < self.size {
// if !self.output.unwritten().is_empty() {
// let mut input = PartialBuffer::new(&self.input[self.input_flushed..self.size]);
// self.zstd
// .get_or_insert_with(|| ZstdEncoder::new(22))
// .encode(&mut input, &mut self.output)?;
// self.input_flushed += input.written().len();
// }
// } else {
// if !self.output.unwritten().is_empty() {
// if self.zstd.as_mut().unwrap().finish(&mut self.output)? {
// self.zstd = None;
// }
// }
// if self.zstd.is_none() && self.output.written().len() == self.output_flushed {
// self.output_flushed = 0;
// self.output.reset();
// if let Some(addr) = self.size_addr {
// let end_addr = if let Some(end_addr) = self.end_addr {
// end_addr
// } else {
// let end_addr = self.writer.stream_position()?;
// self.end_addr = Some(end_addr);
// end_addr
// };
// self.writer.seek(std::io::SeekFrom::Start(addr))?;
// self.output.unwritten_mut()[..2]
// .copy_from_slice(&((end_addr - addr - 2) as u16).to_le_bytes());
// self.output.advance(2);
// self.size_addr = None;
// }
// if let Some(end_addr) = self.end_addr {
// self.writer.seek(std::io::SeekFrom::Start(end_addr))?;
// self.end_addr = None;
// self.input_flushed = 0;
// self.size = 0;
// }
// }
// }
// }
Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::task::Poll::Ready(Ok(()))
}
}