mirror of
https://github.com/Start9Labs/start-os.git
synced 2026-03-30 12:11:56 +00:00
wip
This commit is contained in:
@@ -85,16 +85,16 @@ async-acme = { version = "0.6.0", git = "https://github.com/dr-bonez/async-acme.
|
||||
"use_rustls",
|
||||
"use_tokio",
|
||||
] }
|
||||
async-compression = { version = "0.4.4", features = [
|
||||
async-compression = { version = "0.4.32", features = [
|
||||
"gzip",
|
||||
"brotli",
|
||||
"zstd",
|
||||
"tokio",
|
||||
] }
|
||||
async-stream = "0.3.5"
|
||||
async-trait = "0.1.74"
|
||||
axum = { version = "0.8.4", features = ["ws"] }
|
||||
barrage = "0.2.3"
|
||||
backhand = "0.21.0"
|
||||
backtrace-on-stack-overflow = { version = "0.3.0", optional = true }
|
||||
base32 = "0.5.0"
|
||||
base64 = "0.22.1"
|
||||
@@ -153,6 +153,7 @@ id-pool = { version = "0.2.2", default-features = false, features = [
|
||||
"serde",
|
||||
"u16",
|
||||
] }
|
||||
iddqd = "0.3.14"
|
||||
imbl = { version = "6", features = ["serde", "small-chunks"] }
|
||||
imbl-value = { version = "0.4.3", features = ["ts-rs"] }
|
||||
include_dir = { version = "0.7.3", features = ["metadata"] }
|
||||
@@ -283,6 +284,7 @@ unix-named-pipe = "0.2.0"
|
||||
url = { version = "2.4.1", features = ["serde"] }
|
||||
urlencoding = "2.1.3"
|
||||
uuid = { version = "1.4.1", features = ["v4"] }
|
||||
visit-rs = "0.1.1"
|
||||
x25519-dalek = { version = "2.0.1", features = ["static_secrets"] }
|
||||
zbus = "5.1.1"
|
||||
zeroize = "1.6.0"
|
||||
|
||||
@@ -250,30 +250,6 @@ impl NetworkInterfaceInfo {
|
||||
}
|
||||
(&*LO, &*LOOPBACK)
|
||||
}
|
||||
pub fn lxc_bridge() -> (&'static GatewayId, &'static Self) {
|
||||
lazy_static! {
|
||||
static ref LXCBR0: GatewayId =
|
||||
GatewayId::from(InternedString::intern(START9_BRIDGE_IFACE));
|
||||
static ref LXC_BRIDGE: NetworkInterfaceInfo = NetworkInterfaceInfo {
|
||||
name: Some(InternedString::from_static("LXC Bridge Interface")),
|
||||
public: Some(false),
|
||||
secure: Some(true),
|
||||
ip_info: Some(IpInfo {
|
||||
name: START9_BRIDGE_IFACE.into(),
|
||||
scope_id: 0,
|
||||
device_type: None,
|
||||
subnets: [IpNet::new(HOST_IP.into(), 24).unwrap()]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
lan_ip: [IpAddr::from(HOST_IP)].into_iter().collect(),
|
||||
wan_ip: None,
|
||||
ntp_servers: Default::default(),
|
||||
dns_servers: Default::default(),
|
||||
}),
|
||||
};
|
||||
}
|
||||
(&*LXCBR0, &*LXC_BRIDGE)
|
||||
}
|
||||
pub fn public(&self) -> bool {
|
||||
self.public.unwrap_or_else(|| {
|
||||
!self.ip_info.as_ref().map_or(true, |ip_info| {
|
||||
@@ -339,7 +315,9 @@ pub struct IpInfo {
|
||||
pub enum NetworkInterfaceType {
|
||||
Ethernet,
|
||||
Wireless,
|
||||
Bridge,
|
||||
Wireguard,
|
||||
Loopback,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, HasModel, TS)]
|
||||
|
||||
@@ -415,10 +415,7 @@ impl Resolver {
|
||||
{
|
||||
if let Some(res) = self.net_iface.peek(|i| {
|
||||
i.values()
|
||||
.chain([
|
||||
NetworkInterfaceInfo::loopback().1,
|
||||
NetworkInterfaceInfo::lxc_bridge().1,
|
||||
])
|
||||
.chain([NetworkInterfaceInfo::loopback().1])
|
||||
.filter_map(|i| i.ip_info.as_ref())
|
||||
.find(|i| i.subnets.iter().any(|s| s.contains(&src)))
|
||||
.map(|ip_info| {
|
||||
|
||||
@@ -5,6 +5,7 @@ use std::sync::{Arc, Weak};
|
||||
use futures::channel::oneshot;
|
||||
use helpers::NonDetachingJoinHandle;
|
||||
use id_pool::IdPool;
|
||||
use iddqd::{IdOrdItem, IdOrdMap};
|
||||
use imbl::OrdMap;
|
||||
use models::GatewayId;
|
||||
use rpc_toolkit::{from_fn_async, Context, HandlerArgs, HandlerExt, ParentHandler};
|
||||
@@ -14,7 +15,7 @@ use tokio::sync::mpsc;
|
||||
|
||||
use crate::context::{CliContext, RpcContext};
|
||||
use crate::db::model::public::NetworkInterfaceInfo;
|
||||
use crate::net::gateway::{DynInterfaceFilter, InterfaceFilter, SecureFilter};
|
||||
use crate::net::gateway::{DynInterfaceFilter, InterfaceFilter};
|
||||
use crate::net::utils::ipv6_is_link_local;
|
||||
use crate::prelude::*;
|
||||
use crate::util::serde::{display_serializable, HandlerExtSerde};
|
||||
@@ -60,17 +61,10 @@ pub fn forward_api<C: Context>() -> ParentHandler<C> {
|
||||
}
|
||||
|
||||
let mut table = Table::new();
|
||||
table.add_row(row![bc => "FROM", "TO", "FILTER / GATEWAY"]);
|
||||
table.add_row(row![bc => "FROM", "TO", "FILTER"]);
|
||||
|
||||
for (external, target) in res.0 {
|
||||
table.add_row(row![external, target.target, target.filter]);
|
||||
for (source, gateway) in target.gateways {
|
||||
table.add_row(row![
|
||||
format!("{}:{}", source, external),
|
||||
target.target,
|
||||
gateway
|
||||
]);
|
||||
}
|
||||
}
|
||||
|
||||
table.print_tty(false)?;
|
||||
@@ -85,41 +79,43 @@ struct ForwardRequest {
|
||||
external: u16,
|
||||
target: SocketAddr,
|
||||
filter: DynInterfaceFilter,
|
||||
rc: Weak<()>,
|
||||
rc: Arc<()>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ForwardEntry {
|
||||
external: u16,
|
||||
target: SocketAddr,
|
||||
prev_filter: DynInterfaceFilter,
|
||||
forwards: BTreeMap<SocketAddr, GatewayId>,
|
||||
rc: Weak<()>,
|
||||
filter: BTreeMap<DynInterfaceFilter, (SocketAddr, Weak<()>)>,
|
||||
forwards: BTreeMap<SocketAddr, (GatewayId, SocketAddr)>,
|
||||
}
|
||||
impl IdOrdItem for ForwardEntry {
|
||||
type Key<'a> = u16;
|
||||
fn key(&self) -> Self::Key<'_> {
|
||||
self.external
|
||||
}
|
||||
|
||||
iddqd::id_upcast!();
|
||||
}
|
||||
impl ForwardEntry {
|
||||
fn new(external: u16, target: SocketAddr, rc: Weak<()>) -> Self {
|
||||
fn new(external: u16) -> Self {
|
||||
Self {
|
||||
external,
|
||||
target,
|
||||
prev_filter: false.into_dyn(),
|
||||
filter: BTreeMap::new(),
|
||||
forwards: BTreeMap::new(),
|
||||
rc,
|
||||
}
|
||||
}
|
||||
|
||||
fn take(&mut self) -> Self {
|
||||
Self {
|
||||
external: self.external,
|
||||
target: self.target,
|
||||
prev_filter: std::mem::replace(&mut self.prev_filter, false.into_dyn()),
|
||||
filter: std::mem::take(&mut self.filter),
|
||||
forwards: std::mem::take(&mut self.forwards),
|
||||
rc: self.rc.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn destroy(mut self) -> Result<(), Error> {
|
||||
while let Some((source, interface)) = self.forwards.pop_first() {
|
||||
unforward(interface.as_str(), source, self.target).await?;
|
||||
while let Some((source, (interface, target))) = self.forwards.pop_first() {
|
||||
unforward(interface.as_str(), source, target).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -127,38 +123,37 @@ impl ForwardEntry {
|
||||
async fn update(
|
||||
&mut self,
|
||||
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
|
||||
filter: Option<DynInterfaceFilter>,
|
||||
) -> Result<(), Error> {
|
||||
if self.rc.strong_count() == 0 {
|
||||
return self.take().destroy().await;
|
||||
}
|
||||
let filter_ref = filter.as_ref().unwrap_or(&self.prev_filter);
|
||||
let mut keep = BTreeSet::<SocketAddr>::new();
|
||||
for (iface, info) in ip_info
|
||||
.iter()
|
||||
// .chain([NetworkInterfaceInfo::loopback()])
|
||||
.filter(|(id, info)| filter_ref.filter(*id, *info))
|
||||
{
|
||||
if let Some(ip_info) = &info.ip_info {
|
||||
for ipnet in &ip_info.subnets {
|
||||
let addr = match ipnet.addr() {
|
||||
IpAddr::V6(ip6) => SocketAddrV6::new(
|
||||
ip6,
|
||||
self.external,
|
||||
0,
|
||||
if ipv6_is_link_local(ip6) {
|
||||
ip_info.scope_id
|
||||
} else {
|
||||
0
|
||||
},
|
||||
)
|
||||
.into(),
|
||||
ip => SocketAddr::new(ip, self.external),
|
||||
};
|
||||
keep.insert(addr);
|
||||
if !self.forwards.contains_key(&addr) {
|
||||
forward(iface.as_str(), addr, self.target).await?;
|
||||
self.forwards.insert(addr, iface.clone());
|
||||
for (iface, info) in ip_info.iter() {
|
||||
if let Some(target) = self
|
||||
.filter
|
||||
.iter()
|
||||
.filter(|(_, (_, rc))| rc.strong_count() > 0)
|
||||
.find(|(filter, _)| filter.filter(iface, info))
|
||||
.map(|(_, (target, _))| *target)
|
||||
{
|
||||
if let Some(ip_info) = &info.ip_info {
|
||||
for ipnet in &ip_info.subnets {
|
||||
let addr = match ipnet.addr() {
|
||||
IpAddr::V6(ip6) => SocketAddrV6::new(
|
||||
ip6,
|
||||
self.external,
|
||||
0,
|
||||
if ipv6_is_link_local(ip6) {
|
||||
ip_info.scope_id
|
||||
} else {
|
||||
0
|
||||
},
|
||||
)
|
||||
.into(),
|
||||
ip => SocketAddr::new(ip, self.external),
|
||||
};
|
||||
keep.insert(addr);
|
||||
if !self.forwards.contains_key(&addr) {
|
||||
forward(iface.as_str(), addr, target).await?;
|
||||
self.forwards.insert(addr, (iface.clone(), target));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -170,13 +165,10 @@ impl ForwardEntry {
|
||||
.filter(|a| !keep.contains(a))
|
||||
.collect::<Vec<_>>();
|
||||
for rm in rm {
|
||||
if let Some((source, interface)) = self.forwards.remove_entry(&rm) {
|
||||
unforward(interface.as_str(), source, self.target).await?;
|
||||
if let Some((source, (interface, target))) = self.forwards.remove_entry(&rm) {
|
||||
unforward(interface.as_str(), source, target).await?;
|
||||
}
|
||||
}
|
||||
if let Some(filter) = filter {
|
||||
self.prev_filter = filter;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -186,20 +178,34 @@ impl ForwardEntry {
|
||||
external,
|
||||
target,
|
||||
filter,
|
||||
rc,
|
||||
mut rc,
|
||||
}: ForwardRequest,
|
||||
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
|
||||
) -> Result<(), Error> {
|
||||
if external != self.external || target != self.target {
|
||||
self.take().destroy().await?;
|
||||
*self = Self::new(external, target, rc);
|
||||
self.update(ip_info, Some(filter)).await?;
|
||||
} else {
|
||||
self.rc = rc;
|
||||
self.update(ip_info, Some(filter).filter(|f| f != &self.prev_filter))
|
||||
.await?;
|
||||
) -> Result<Arc<()>, Error> {
|
||||
if external != self.external {
|
||||
return Err(Error::new(
|
||||
eyre!("Mismatched external port in ForwardEntry"),
|
||||
ErrorKind::InvalidRequest,
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
|
||||
let entry = self
|
||||
.filter
|
||||
.entry(filter)
|
||||
.or_insert_with(|| (target, Arc::downgrade(&rc)));
|
||||
if entry.0 != target {
|
||||
entry.0 = target;
|
||||
entry.1 = Arc::downgrade(&rc);
|
||||
}
|
||||
if let Some(existing) = entry.1.upgrade() {
|
||||
rc = existing;
|
||||
} else {
|
||||
entry.1 = Arc::downgrade(&rc);
|
||||
}
|
||||
|
||||
self.update(ip_info).await?;
|
||||
|
||||
Ok(rc)
|
||||
}
|
||||
}
|
||||
impl Drop for ForwardEntry {
|
||||
@@ -215,17 +221,17 @@ impl Drop for ForwardEntry {
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
struct ForwardState {
|
||||
state: BTreeMap<u16, ForwardEntry>,
|
||||
state: IdOrdMap<ForwardEntry>,
|
||||
}
|
||||
impl ForwardState {
|
||||
async fn handle_request(
|
||||
&mut self,
|
||||
request: ForwardRequest,
|
||||
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
|
||||
) -> Result<(), Error> {
|
||||
) -> Result<Arc<()>, Error> {
|
||||
self.state
|
||||
.entry(request.external)
|
||||
.or_insert_with(|| ForwardEntry::new(request.external, request.target, Weak::new()))
|
||||
.or_insert_with(|| ForwardEntry::new(request.external))
|
||||
.update_request(request, ip_info)
|
||||
.await
|
||||
}
|
||||
@@ -233,10 +239,9 @@ impl ForwardState {
|
||||
&mut self,
|
||||
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
|
||||
) -> Result<(), Error> {
|
||||
for entry in self.state.values_mut() {
|
||||
entry.update(ip_info, None).await?;
|
||||
for mut entry in self.state.iter_mut() {
|
||||
entry.update(ip_info).await?;
|
||||
}
|
||||
self.state.retain(|_, fwd| fwd.rc.strong_count() > 0);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -254,7 +259,6 @@ pub struct ForwardTable(pub BTreeMap<u16, ForwardTarget>);
|
||||
pub struct ForwardTarget {
|
||||
pub target: SocketAddr,
|
||||
pub filter: String,
|
||||
pub gateways: BTreeMap<SocketAddr, GatewayId>,
|
||||
}
|
||||
impl From<&ForwardState> for ForwardTable {
|
||||
fn from(value: &ForwardState) -> Self {
|
||||
@@ -262,15 +266,20 @@ impl From<&ForwardState> for ForwardTable {
|
||||
value
|
||||
.state
|
||||
.iter()
|
||||
.map(|(external, entry)| {
|
||||
(
|
||||
*external,
|
||||
ForwardTarget {
|
||||
target: entry.target,
|
||||
filter: format!("{:?}", entry.prev_filter),
|
||||
gateways: entry.forwards.clone(),
|
||||
},
|
||||
)
|
||||
.flat_map(|entry| {
|
||||
entry
|
||||
.filter
|
||||
.iter()
|
||||
.filter(|(_, (_, rc))| rc.strong_count() > 0)
|
||||
.map(|(filter, (target, _))| {
|
||||
(
|
||||
entry.external,
|
||||
ForwardTarget {
|
||||
target: *target,
|
||||
filter: format!("{:?}", filter),
|
||||
},
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
@@ -278,13 +287,15 @@ impl From<&ForwardState> for ForwardTable {
|
||||
}
|
||||
|
||||
enum ForwardCommand {
|
||||
Forward(ForwardRequest, oneshot::Sender<Result<(), Error>>),
|
||||
Forward(ForwardRequest, oneshot::Sender<Result<Arc<()>, Error>>),
|
||||
Sync(oneshot::Sender<Result<(), Error>>),
|
||||
DumpTable(oneshot::Sender<ForwardTable>),
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test() {
|
||||
use crate::net::gateway::SecureFilter;
|
||||
|
||||
assert_ne!(
|
||||
false.into_dyn(),
|
||||
SecureFilter { secure: false }.into_dyn().into_dyn()
|
||||
@@ -340,13 +351,13 @@ impl PortForwardController {
|
||||
external,
|
||||
target,
|
||||
filter,
|
||||
rc: Arc::downgrade(&rc),
|
||||
rc,
|
||||
},
|
||||
send,
|
||||
))
|
||||
.map_err(err_has_exited)?;
|
||||
|
||||
recv.await.map_err(err_has_exited)?.map(|_| rc)
|
||||
recv.await.map_err(err_has_exited)?
|
||||
}
|
||||
pub async fn gc(&self) -> Result<(), Error> {
|
||||
let (send, recv) = oneshot::channel();
|
||||
|
||||
@@ -585,18 +585,19 @@ async fn watch_ip(
|
||||
loop {
|
||||
until
|
||||
.run(async {
|
||||
let external = active_connection_proxy.state_flags().await? & 0x80 != 0;
|
||||
if external {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let device_type = match device_proxy.device_type().await? {
|
||||
1 => Some(NetworkInterfaceType::Ethernet),
|
||||
2 => Some(NetworkInterfaceType::Wireless),
|
||||
13 => Some(NetworkInterfaceType::Bridge),
|
||||
29 => Some(NetworkInterfaceType::Wireguard),
|
||||
32 => Some(NetworkInterfaceType::Loopback),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if device_type == Some(NetworkInterfaceType::Loopback) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let name = InternedString::from(active_connection_proxy.id().await?);
|
||||
|
||||
let dhcp4_config = active_connection_proxy.dhcp4_config().await?;
|
||||
@@ -787,13 +788,7 @@ impl NetworkInterfaceWatcher {
|
||||
watch_activated: impl IntoIterator<Item = GatewayId>,
|
||||
) -> Self {
|
||||
let ip_info = Watch::new(OrdMap::new());
|
||||
let activated = Watch::new(
|
||||
watch_activated
|
||||
.into_iter()
|
||||
.chain([NetworkInterfaceInfo::lxc_bridge().0.clone()])
|
||||
.map(|k| (k, false))
|
||||
.collect(),
|
||||
);
|
||||
let activated = Watch::new(watch_activated.into_iter().map(|k| (k, false)).collect());
|
||||
Self {
|
||||
activated: activated.clone(),
|
||||
ip_info: ip_info.clone(),
|
||||
@@ -1384,14 +1379,12 @@ impl ListenerMap {
|
||||
fn update(
|
||||
&mut self,
|
||||
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
|
||||
lxc_bridge: bool,
|
||||
filter: &impl InterfaceFilter,
|
||||
) -> Result<(), Error> {
|
||||
let mut keep = BTreeSet::<SocketAddr>::new();
|
||||
for (_, info) in ip_info
|
||||
.iter()
|
||||
.chain([NetworkInterfaceInfo::loopback()])
|
||||
.chain(Some(NetworkInterfaceInfo::lxc_bridge()).filter(|_| lxc_bridge))
|
||||
.filter(|(id, info)| filter.filter(*id, *info))
|
||||
{
|
||||
if let Some(ip_info) = &info.ip_info {
|
||||
@@ -1466,10 +1459,7 @@ pub fn lookup_info_by_addr(
|
||||
) -> Option<(&GatewayId, &NetworkInterfaceInfo)> {
|
||||
ip_info
|
||||
.iter()
|
||||
.chain([
|
||||
NetworkInterfaceInfo::loopback(),
|
||||
NetworkInterfaceInfo::lxc_bridge(),
|
||||
])
|
||||
.chain([NetworkInterfaceInfo::loopback()])
|
||||
.find(|(_, i)| {
|
||||
i.ip_info
|
||||
.as_ref()
|
||||
@@ -1495,16 +1485,10 @@ impl NetworkInterfaceListener {
|
||||
filter: &impl InterfaceFilter,
|
||||
) -> Poll<Result<Accepted, Error>> {
|
||||
while self.ip_info.poll_changed(cx).is_ready()
|
||||
|| self.activated.poll_changed(cx).is_ready()
|
||||
|| !DynInterfaceFilterT::eq(&self.listeners.prev_filter, filter.as_any())
|
||||
{
|
||||
let lxc_bridge = self.activated.peek(|a| {
|
||||
a.get(NetworkInterfaceInfo::lxc_bridge().0)
|
||||
.copied()
|
||||
.unwrap_or_default()
|
||||
});
|
||||
self.ip_info
|
||||
.peek_and_mark_seen(|ip_info| self.listeners.update(ip_info, lxc_bridge, filter))?;
|
||||
.peek_and_mark_seen(|ip_info| self.listeners.update(ip_info, filter))?;
|
||||
}
|
||||
self.listeners.poll_accept(cx)
|
||||
}
|
||||
@@ -1562,9 +1546,12 @@ impl SelfContainedNetworkInterfaceListener {
|
||||
pub fn bind(port: u16) -> Self {
|
||||
let ip_info = Watch::new(OrdMap::new());
|
||||
let activated = Watch::new(
|
||||
[(NetworkInterfaceInfo::lxc_bridge().0.clone(), false)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
[(
|
||||
GatewayId::from(InternedString::from(START9_BRIDGE_IFACE)),
|
||||
false,
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
);
|
||||
let _watch_thread = tokio::spawn(watcher(ip_info.clone(), activated.clone())).into();
|
||||
Self {
|
||||
|
||||
@@ -6,7 +6,7 @@ use color_eyre::eyre::eyre;
|
||||
use imbl::{vector, OrdMap};
|
||||
use imbl_value::InternedString;
|
||||
use ipnet::IpNet;
|
||||
use models::{HostId, OptionExt, PackageId};
|
||||
use models::{GatewayId, HostId, OptionExt, PackageId};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::instrument;
|
||||
@@ -16,7 +16,7 @@ use crate::db::model::Database;
|
||||
use crate::error::ErrorCollection;
|
||||
use crate::hostname::Hostname;
|
||||
use crate::net::dns::DnsController;
|
||||
use crate::net::forward::PortForwardController;
|
||||
use crate::net::forward::{PortForwardController, START9_BRIDGE_IFACE};
|
||||
use crate::net::gateway::{
|
||||
AndFilter, DynInterfaceFilter, IdFilter, InterfaceFilter, NetworkInterfaceController, OrFilter,
|
||||
PublicFilter, SecureFilter,
|
||||
@@ -283,9 +283,9 @@ impl NetServiceData {
|
||||
IdFilter(
|
||||
NetworkInterfaceInfo::loopback().0.clone(),
|
||||
),
|
||||
IdFilter(
|
||||
NetworkInterfaceInfo::lxc_bridge().0.clone(),
|
||||
),
|
||||
IdFilter(GatewayId::from(InternedString::from(
|
||||
START9_BRIDGE_IFACE,
|
||||
))),
|
||||
)
|
||||
.into_dyn(),
|
||||
acme: None,
|
||||
|
||||
@@ -3,12 +3,10 @@ use tokio::io::{AsyncSeek, AsyncWrite};
|
||||
use crate::prelude::*;
|
||||
use crate::util::io::TrackingIO;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait Sink: AsyncWrite + Unpin + Send {
|
||||
async fn current_position(&mut self) -> Result<u64, Error>;
|
||||
fn current_position(&mut self) -> impl Future<Output = Result<u64, Error>> + Send + '_;
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<S: AsyncWrite + AsyncSeek + Unpin + Send> Sink for S {
|
||||
async fn current_position(&mut self) -> Result<u64, Error> {
|
||||
use tokio::io::AsyncSeekExt;
|
||||
@@ -17,7 +15,6 @@ impl<S: AsyncWrite + AsyncSeek + Unpin + Send> Sink for S {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<W: AsyncWrite + Unpin + Send> Sink for TrackingIO<W> {
|
||||
async fn current_position(&mut self) -> Result<u64, Error> {
|
||||
Ok(self.position())
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
|
||||
|
||||
use clap::Parser;
|
||||
use imbl_value::InternedString;
|
||||
use ipnet::Ipv4Net;
|
||||
use models::GatewayId;
|
||||
use rpc_toolkit::{from_fn_async, Context, Empty, HandlerArgs, HandlerExt, ParentHandler};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::context::CliContext;
|
||||
use crate::net::gateway::{IdFilter, InterfaceFilter};
|
||||
use crate::prelude::*;
|
||||
use crate::tunnel::context::TunnelContext;
|
||||
use crate::tunnel::db::GatewayPort;
|
||||
use crate::tunnel::wg::{ClientConfig, WgConfig, WgSubnetClients, WgSubnetConfig};
|
||||
use crate::util::serde::{display_serializable, HandlerExtSerde};
|
||||
|
||||
@@ -27,26 +30,26 @@ pub fn tunnel_api<C: Context>() -> ParentHandler<C> {
|
||||
"subnet",
|
||||
subnet_api::<C>().with_about("Add, remove, or modify subnets"),
|
||||
)
|
||||
// .subcommand(
|
||||
// "port-forward",
|
||||
// ParentHandler::<C>::new()
|
||||
// .subcommand(
|
||||
// "add",
|
||||
// from_fn_async(add_forward)
|
||||
// .with_metadata("sync_db", Value::Bool(true))
|
||||
// .no_display()
|
||||
// .with_about("Add a new port forward")
|
||||
// .with_call_remote::<CliContext>(),
|
||||
// )
|
||||
// .subcommand(
|
||||
// "remove",
|
||||
// from_fn_async(remove_forward)
|
||||
// .with_metadata("sync_db", Value::Bool(true))
|
||||
// .no_display()
|
||||
// .with_about("Remove a port forward")
|
||||
// .with_call_remote::<CliContext>(),
|
||||
// ),
|
||||
// )
|
||||
.subcommand(
|
||||
"port-forward",
|
||||
ParentHandler::<C>::new()
|
||||
.subcommand(
|
||||
"add",
|
||||
from_fn_async(add_forward)
|
||||
.with_metadata("sync_db", Value::Bool(true))
|
||||
.no_display()
|
||||
.with_about("Add a new port forward")
|
||||
.with_call_remote::<CliContext>(),
|
||||
)
|
||||
.subcommand(
|
||||
"remove",
|
||||
from_fn_async(remove_forward)
|
||||
.with_metadata("sync_db", Value::Bool(true))
|
||||
.no_display()
|
||||
.with_about("Remove a port forward")
|
||||
.with_call_remote::<CliContext>(),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Parser)]
|
||||
@@ -345,3 +348,53 @@ pub async fn show_config(
|
||||
(wan_addr, wg.as_port().de()?).into(),
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Parser)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AddPortForwardParams {
|
||||
source: GatewayPort,
|
||||
target: SocketAddrV4,
|
||||
}
|
||||
|
||||
pub async fn add_forward(
|
||||
ctx: TunnelContext,
|
||||
AddPortForwardParams { source, target }: AddPortForwardParams,
|
||||
) -> Result<(), Error> {
|
||||
ctx.db
|
||||
.mutate(|db| db.as_port_forwards_mut().insert(&source, &target))
|
||||
.await
|
||||
.result?;
|
||||
let rc = ctx
|
||||
.forward
|
||||
.add(
|
||||
source.1,
|
||||
IdFilter(source.0.clone()).into_dyn(),
|
||||
target.into(),
|
||||
)
|
||||
.await?;
|
||||
ctx.active_forwards.mutate(|m| {
|
||||
m.insert(source, rc);
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Parser)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RemovePortForwardParams {
|
||||
source: GatewayPort,
|
||||
}
|
||||
|
||||
pub async fn remove_forward(
|
||||
ctx: TunnelContext,
|
||||
RemovePortForwardParams { source, .. }: RemovePortForwardParams,
|
||||
) -> Result<(), Error> {
|
||||
ctx.db
|
||||
.mutate(|db| db.as_port_forwards_mut().remove(&source))
|
||||
.await
|
||||
.result?;
|
||||
if let Some(rc) = ctx.active_forwards.mutate(|m| m.remove(&source)) {
|
||||
drop(rc);
|
||||
ctx.forward.gc().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::collections::BTreeSet;
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
|
||||
use std::ops::Deref;
|
||||
use std::path::{Path, PathBuf};
|
||||
@@ -10,6 +10,7 @@ use helpers::NonDetachingJoinHandle;
|
||||
use http::HeaderMap;
|
||||
use imbl::OrdMap;
|
||||
use imbl_value::InternedString;
|
||||
use models::GatewayId;
|
||||
use patch_db::PatchDb;
|
||||
use rpc_toolkit::yajrc::RpcError;
|
||||
use rpc_toolkit::{CallRemote, Context, Empty};
|
||||
@@ -22,12 +23,13 @@ use url::Url;
|
||||
use crate::auth::Sessions;
|
||||
use crate::context::config::ContextConfig;
|
||||
use crate::context::{CliContext, RpcContext};
|
||||
use crate::db::model::public::NetworkInterfaceType;
|
||||
use crate::middleware::auth::AuthContext;
|
||||
use crate::net::forward::PortForwardController;
|
||||
use crate::net::gateway::NetworkInterfaceWatcher;
|
||||
use crate::net::gateway::{IdFilter, InterfaceFilter, NetworkInterfaceWatcher};
|
||||
use crate::prelude::*;
|
||||
use crate::rpc_continuations::{OpenAuthedContinuations, RpcContinuations};
|
||||
use crate::tunnel::db::TunnelDatabase;
|
||||
use crate::tunnel::db::{GatewayPort, TunnelDatabase};
|
||||
use crate::tunnel::TUNNEL_DEFAULT_PORT;
|
||||
use crate::util::io::read_file_to_string;
|
||||
use crate::util::sync::SyncMutex;
|
||||
@@ -73,6 +75,7 @@ pub struct TunnelContextSeed {
|
||||
pub ephemeral_sessions: SyncMutex<Sessions>,
|
||||
pub net_iface: NetworkInterfaceWatcher,
|
||||
pub forward: PortForwardController,
|
||||
pub active_forwards: SyncMutex<BTreeMap<GatewayPort, Arc<()>>>,
|
||||
pub masquerade_thread: NonDetachingJoinHandle<()>,
|
||||
pub shutdown: Sender<()>,
|
||||
}
|
||||
@@ -114,7 +117,17 @@ impl TunnelContext {
|
||||
let mut masquerade_net_iface = net_iface.subscribe();
|
||||
let masquerade_thread = tokio::spawn(async move {
|
||||
loop {
|
||||
for iface in masquerade_net_iface.peek(|i| i.keys().cloned().collect::<Vec<_>>()) {
|
||||
for iface in masquerade_net_iface.peek(|i| {
|
||||
i.iter()
|
||||
.filter(|(_, info)| {
|
||||
dbg!(info).ip_info.as_ref().map_or(false, |i| {
|
||||
dbg!(i).device_type != Some(NetworkInterfaceType::Wireguard)
|
||||
})
|
||||
})
|
||||
.map(|(name, _)| name)
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
}) {
|
||||
if Command::new("iptables")
|
||||
.arg("-t")
|
||||
.arg("nat")
|
||||
@@ -128,6 +141,7 @@ impl TunnelContext {
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
tracing::info!("Adding masquerade rule for interface {}", iface);
|
||||
Command::new("iptables")
|
||||
.arg("-t")
|
||||
.arg("nat")
|
||||
@@ -144,11 +158,23 @@ impl TunnelContext {
|
||||
}
|
||||
|
||||
masquerade_net_iface.changed().await;
|
||||
|
||||
tracing::info!("Network interfaces changed, updating masquerade rules");
|
||||
}
|
||||
})
|
||||
.into();
|
||||
|
||||
db.peek().await.into_wg().de()?.sync().await?;
|
||||
let peek = db.peek().await;
|
||||
peek.as_wg().de()?.sync().await?;
|
||||
let mut active_forwards = BTreeMap::new();
|
||||
for (from, to) in peek.as_port_forwards().de()?.0 {
|
||||
active_forwards.insert(
|
||||
from.clone(),
|
||||
forward
|
||||
.add(from.1, IdFilter(from.0).into_dyn(), to.into())
|
||||
.await?,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Self(Arc::new(TunnelContextSeed {
|
||||
listen,
|
||||
@@ -164,6 +190,7 @@ impl TunnelContext {
|
||||
ephemeral_sessions: SyncMutex::new(Sessions::new()),
|
||||
net_iface,
|
||||
forward,
|
||||
active_forwards: SyncMutex::new(active_forwards),
|
||||
masquerade_thread,
|
||||
shutdown,
|
||||
})))
|
||||
|
||||
@@ -2,9 +2,12 @@ use std::collections::BTreeMap;
|
||||
use std::net::SocketAddrV4;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use clap::builder::ValueParserFactory;
|
||||
use clap::Parser;
|
||||
use imbl::HashMap;
|
||||
use imbl_value::InternedString;
|
||||
use itertools::Itertools;
|
||||
use models::{FromStrParser, GatewayId};
|
||||
use patch_db::json_ptr::{JsonPointer, ROOT};
|
||||
use patch_db::Dump;
|
||||
use rpc_toolkit::yajrc::RpcError;
|
||||
@@ -20,7 +23,52 @@ use crate::sign::AnyVerifyingKey;
|
||||
use crate::tunnel::auth::SignerInfo;
|
||||
use crate::tunnel::context::TunnelContext;
|
||||
use crate::tunnel::wg::WgServer;
|
||||
use crate::util::serde::{apply_expr, HandlerExtSerde};
|
||||
use crate::util::serde::{apply_expr, deserialize_from_str, serialize_display, HandlerExtSerde};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub struct GatewayPort(pub GatewayId, pub u16);
|
||||
impl std::fmt::Display for GatewayPort {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}:{}", self.0, self.1)
|
||||
}
|
||||
}
|
||||
impl std::str::FromStr for GatewayPort {
|
||||
type Err = crate::Error;
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let mut parts = s.splitn(2, ':');
|
||||
let gw: GatewayId = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::new(eyre!("missing gateway id"), ErrorKind::ParseNetAddress))?
|
||||
.parse()?;
|
||||
let port: u16 = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::new(eyre!("missing port"), ErrorKind::ParseNetAddress))?
|
||||
.parse()?;
|
||||
Ok(GatewayPort(gw, port))
|
||||
}
|
||||
}
|
||||
impl Serialize for GatewayPort {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serialize_display(self, serializer)
|
||||
}
|
||||
}
|
||||
impl<'de> Deserialize<'de> for GatewayPort {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
deserialize_from_str(deserializer)
|
||||
}
|
||||
}
|
||||
impl ValueParserFactory for GatewayPort {
|
||||
type Parser = FromStrParser<Self>;
|
||||
fn value_parser() -> Self::Parser {
|
||||
FromStrParser::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Deserialize, Serialize, HasModel)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -30,7 +78,20 @@ pub struct TunnelDatabase {
|
||||
pub password: String,
|
||||
pub auth_pubkeys: HashMap<AnyVerifyingKey, SignerInfo>,
|
||||
pub wg: WgServer,
|
||||
pub port_forwards: BTreeMap<SocketAddrV4, SocketAddrV4>,
|
||||
pub port_forwards: PortForwards,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
|
||||
pub struct PortForwards(pub BTreeMap<GatewayPort, SocketAddrV4>);
|
||||
impl Map for PortForwards {
|
||||
type Key = GatewayPort;
|
||||
type Value = SocketAddrV4;
|
||||
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
|
||||
Self::key_string(key)
|
||||
}
|
||||
fn key_string(key: &Self::Key) -> Result<InternedString, Error> {
|
||||
Ok(InternedString::from_display(key))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn db_api<C: Context>() -> ParentHandler<C> {
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
use crate::prelude::*;
|
||||
@@ -1,13 +1,11 @@
|
||||
use axum::Router;
|
||||
use futures::future::ready;
|
||||
use rpc_toolkit::{from_fn_async, Context, HandlerExt, ParentHandler, Server};
|
||||
use rpc_toolkit::Server;
|
||||
|
||||
use crate::context::CliContext;
|
||||
use crate::middleware::auth::Auth;
|
||||
use crate::middleware::cors::Cors;
|
||||
use crate::net::static_server::{bad_request, not_found, server_error};
|
||||
use crate::net::web_server::{Accept, WebServer};
|
||||
use crate::prelude::*;
|
||||
use crate::rpc_continuations::Guid;
|
||||
use crate::tunnel::context::TunnelContext;
|
||||
|
||||
@@ -15,7 +13,6 @@ pub mod api;
|
||||
pub mod auth;
|
||||
pub mod context;
|
||||
pub mod db;
|
||||
pub mod forward;
|
||||
pub mod wg;
|
||||
|
||||
pub const TUNNEL_DEFAULT_PORT: u16 = 5960;
|
||||
|
||||
@@ -14,7 +14,7 @@ use std::time::Duration;
|
||||
use bytes::{Buf, BytesMut};
|
||||
use clap::builder::ValueParserFactory;
|
||||
use futures::future::{BoxFuture, Fuse};
|
||||
use futures::{AsyncSeek, FutureExt, Stream, StreamExt, TryStreamExt};
|
||||
use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
|
||||
use helpers::{AtomicFile, NonDetachingJoinHandle};
|
||||
use inotify::{EventMask, EventStream, Inotify, WatchMask};
|
||||
use models::FromStrParser;
|
||||
@@ -22,7 +22,8 @@ use nix::unistd::{Gid, Uid};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::fs::{File, OpenOptions};
|
||||
use tokio::io::{
|
||||
duplex, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf, WriteHalf,
|
||||
duplex, AsyncRead, AsyncReadExt, AsyncSeek, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf,
|
||||
SeekFrom, WriteHalf,
|
||||
};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{Notify, OwnedMutexGuard};
|
||||
|
||||
@@ -49,6 +49,7 @@ pub mod net;
|
||||
pub mod rpc;
|
||||
pub mod rpc_client;
|
||||
pub mod serde;
|
||||
pub mod squashfs;
|
||||
pub mod sync;
|
||||
|
||||
#[derive(Clone, Copy, Debug, ::serde::Deserialize, ::serde::Serialize)]
|
||||
|
||||
268
core/startos/src/util/squashfs.rs
Normal file
268
core/startos/src/util/squashfs.rs
Normal file
@@ -0,0 +1,268 @@
|
||||
use std::io::{Seek, Write};
|
||||
use std::path::Path;
|
||||
use std::task::Poll;
|
||||
|
||||
use async_compression::codecs::{Encode, ZstdEncoder};
|
||||
use async_compression::core::util::PartialBuffer;
|
||||
use futures::{ready, TryStreamExt};
|
||||
use tokio::io::{AsyncSeek, AsyncWrite};
|
||||
use visit_rs::{Visit, VisitAsync, VisitFields, VisitFieldsAsync, Visitor};
|
||||
|
||||
use crate::prelude::*;
|
||||
use crate::registry::os::asset::add;
|
||||
|
||||
struct SquashfsSerializer<W> {
|
||||
writer: W,
|
||||
}
|
||||
impl<W> Visitor for SquashfsSerializer<W> {
|
||||
type Result = Result<(), Error>;
|
||||
}
|
||||
|
||||
macro_rules! impl_visit_le {
|
||||
($($ty:ty),*) => {
|
||||
$(
|
||||
impl<W: AsyncWrite + Unpin + Send> VisitAsync<SquashfsSerializer<W>> for $ty {
|
||||
async fn visit_async(&self, visitor: &mut SquashfsSerializer<W>) -> Result<(), Error> {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
visitor.writer.write_all(&self.to_le_bytes()).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
impl<W: Write> Visit<SquashfsSerializer<W>> for $ty {
|
||||
fn visit(&self, visitor: &mut SquashfsSerializer<W>) -> Result<(), Error> {
|
||||
visitor.writer.write_all(&self.to_le_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
impl_visit_le!(u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
|
||||
|
||||
#[derive(VisitFields)]
|
||||
struct Superblock {
|
||||
magic: u32, // 0x73717368
|
||||
inode_count: u32,
|
||||
modification_time: u32, // 0
|
||||
block_size: u32,
|
||||
fragment_entry_count: u32,
|
||||
compression_id: u16, // 6 = zstd
|
||||
block_log: u16, // log2(block_size)
|
||||
flags: u16, // 0x0440
|
||||
id_count: u16,
|
||||
version_major: u16, // 4
|
||||
version_minor: u16, // 0
|
||||
root_inode_ref: u64,
|
||||
bytes_used: u64,
|
||||
id_table_start: u64,
|
||||
xattr_id_table_start: u64,
|
||||
inode_table_start: u64,
|
||||
directory_table_start: u64,
|
||||
fragment_table_start: u64,
|
||||
export_table_start: u64,
|
||||
}
|
||||
impl Default for Superblock {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
magic: 0x73717368,
|
||||
inode_count: 0,
|
||||
modification_time: 0,
|
||||
block_size: 0,
|
||||
fragment_entry_count: 0,
|
||||
compression_id: 6,
|
||||
block_log: 0,
|
||||
flags: 0x0440,
|
||||
id_count: 0,
|
||||
version_major: 4,
|
||||
version_minor: 0,
|
||||
root_inode_ref: 0,
|
||||
bytes_used: 0,
|
||||
id_table_start: 0,
|
||||
xattr_id_table_start: 0,
|
||||
inode_table_start: 0,
|
||||
directory_table_start: 0,
|
||||
fragment_table_start: 0,
|
||||
export_table_start: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<W: AsyncWrite + Unpin + Send> VisitAsync<SquashfsSerializer<W>> for Superblock {
|
||||
async fn visit_async(&self, visitor: &mut SquashfsSerializer<W>) -> Result<(), Error> {
|
||||
self.visit_fields_async(visitor).try_collect().await
|
||||
}
|
||||
}
|
||||
impl<W: Write> Visit<SquashfsSerializer<W>> for Superblock {
|
||||
fn visit(&self, visitor: &mut SquashfsSerializer<W>) -> Result<(), Error> {
|
||||
self.visit_fields(visitor).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[pin_project::pin_project]
|
||||
pub struct MetadataBlocks<W> {
|
||||
input: [u8; 8192],
|
||||
input_flushed: usize,
|
||||
size: usize,
|
||||
size_addr: Option<u64>,
|
||||
end_addr: Option<u64>,
|
||||
zstd: Option<ZstdEncoder>,
|
||||
output: PartialBuffer<[u8; 4096]>,
|
||||
output_flushed: usize,
|
||||
#[pin]
|
||||
writer: W,
|
||||
}
|
||||
impl<W: Write + Seek> Write for MetadataBlocks<W> {
|
||||
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
let n = buf.len().min(self.input.len() - self.size);
|
||||
self.input[self.size..self.size + n].copy_from_slice(&buf[..n]);
|
||||
if n < buf.len() {
|
||||
self.flush()?;
|
||||
}
|
||||
Ok(n)
|
||||
}
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
if self.size > 0 {
|
||||
if self.size_addr.is_none() {
|
||||
self.size_addr = Some(self.writer.stream_position()?);
|
||||
self.output.unwritten_mut()[..2].copy_from_slice(&[0; 2]);
|
||||
self.output.advance(2);
|
||||
}
|
||||
if self.output.written().len() > self.output_flushed {
|
||||
let n = self
|
||||
.writer
|
||||
.write(&self.output.written()[self.output_flushed..])?;
|
||||
self.output_flushed += n;
|
||||
}
|
||||
if self.output.written().len() == self.output_flushed {
|
||||
self.output_flushed = 0;
|
||||
self.output.reset();
|
||||
}
|
||||
if self.input_flushed < self.size {
|
||||
if !self.output.unwritten().is_empty() {
|
||||
let mut input = PartialBuffer::new(&self.input[self.input_flushed..self.size]);
|
||||
self.zstd
|
||||
.get_or_insert_with(|| ZstdEncoder::new(22))
|
||||
.encode(&mut input, &mut self.output)?;
|
||||
self.input_flushed += input.written().len();
|
||||
}
|
||||
} else {
|
||||
if !self.output.unwritten().is_empty() {
|
||||
if self.zstd.as_mut().unwrap().finish(&mut self.output)? {
|
||||
self.zstd = None;
|
||||
}
|
||||
}
|
||||
if self.zstd.is_none() && self.output.written().len() == self.output_flushed {
|
||||
self.output_flushed = 0;
|
||||
self.output.reset();
|
||||
if let Some(addr) = self.size_addr {
|
||||
let end_addr = if let Some(end_addr) = self.end_addr {
|
||||
end_addr
|
||||
} else {
|
||||
let end_addr = self.writer.stream_position()?;
|
||||
self.end_addr = Some(end_addr);
|
||||
end_addr
|
||||
};
|
||||
self.writer.seek(std::io::SeekFrom::Start(addr))?;
|
||||
self.output.unwritten_mut()[..2]
|
||||
.copy_from_slice(&((end_addr - addr - 2) as u16).to_le_bytes());
|
||||
self.output.advance(2);
|
||||
self.size_addr = None;
|
||||
}
|
||||
if let Some(end_addr) = self.end_addr {
|
||||
self.writer.seek(std::io::SeekFrom::Start(end_addr))?;
|
||||
self.end_addr = None;
|
||||
self.input_flushed = 0;
|
||||
self.size = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
impl<W: AsyncWrite + AsyncSeek> AsyncWrite for MetadataBlocks<W> {
|
||||
fn poll_write(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<std::io::Result<usize>> {
|
||||
let this = self.as_mut().project();
|
||||
let n = buf.len().min(this.input.len() - *this.size);
|
||||
this.input[*this.size..*this.size + n].copy_from_slice(&buf[..n]);
|
||||
if n < buf.len() {
|
||||
ready!(self.poll_flush(cx)?);
|
||||
}
|
||||
Poll::Ready(Ok(n))
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
// let this = self.as_mut();
|
||||
// if self.size > 0 {
|
||||
// if self.size_addr.is_none() {
|
||||
// self.size_addr = Some(self.writer.stream_position()?);
|
||||
// self.output.unwritten_mut()[..2].copy_from_slice(&[0; 2]);
|
||||
// self.output.advance(2);
|
||||
// }
|
||||
// if self.output.written().len() > self.output_flushed {
|
||||
// let n = self
|
||||
// .writer
|
||||
// .write(&self.output.written()[self.output_flushed..])?;
|
||||
// self.output_flushed += n;
|
||||
// }
|
||||
// if self.output.written().len() == self.output_flushed {
|
||||
// self.output_flushed = 0;
|
||||
// self.output.reset();
|
||||
// }
|
||||
// if self.input_flushed < self.size {
|
||||
// if !self.output.unwritten().is_empty() {
|
||||
// let mut input = PartialBuffer::new(&self.input[self.input_flushed..self.size]);
|
||||
// self.zstd
|
||||
// .get_or_insert_with(|| ZstdEncoder::new(22))
|
||||
// .encode(&mut input, &mut self.output)?;
|
||||
// self.input_flushed += input.written().len();
|
||||
// }
|
||||
// } else {
|
||||
// if !self.output.unwritten().is_empty() {
|
||||
// if self.zstd.as_mut().unwrap().finish(&mut self.output)? {
|
||||
// self.zstd = None;
|
||||
// }
|
||||
// }
|
||||
// if self.zstd.is_none() && self.output.written().len() == self.output_flushed {
|
||||
// self.output_flushed = 0;
|
||||
// self.output.reset();
|
||||
// if let Some(addr) = self.size_addr {
|
||||
// let end_addr = if let Some(end_addr) = self.end_addr {
|
||||
// end_addr
|
||||
// } else {
|
||||
// let end_addr = self.writer.stream_position()?;
|
||||
// self.end_addr = Some(end_addr);
|
||||
// end_addr
|
||||
// };
|
||||
// self.writer.seek(std::io::SeekFrom::Start(addr))?;
|
||||
// self.output.unwritten_mut()[..2]
|
||||
// .copy_from_slice(&((end_addr - addr - 2) as u16).to_le_bytes());
|
||||
// self.output.advance(2);
|
||||
// self.size_addr = None;
|
||||
// }
|
||||
// if let Some(end_addr) = self.end_addr {
|
||||
// self.writer.seek(std::io::SeekFrom::Start(end_addr))?;
|
||||
// self.end_addr = None;
|
||||
// self.input_flushed = 0;
|
||||
// self.size = 0;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::task::Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user