mirror of
https://github.com/Start9Labs/start-os.git
synced 2026-03-30 04:01:58 +00:00
refactor public/private gateways
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::process::Command;
|
||||
@@ -48,7 +47,7 @@ async fn setup_or_init(
|
||||
update_phase.complete();
|
||||
reboot_phase.start();
|
||||
return Ok(Err(Shutdown {
|
||||
export_args: None,
|
||||
disk_guid: None,
|
||||
restart: true,
|
||||
}));
|
||||
}
|
||||
@@ -103,7 +102,7 @@ async fn setup_or_init(
|
||||
.expect("context dropped");
|
||||
|
||||
return Ok(Err(Shutdown {
|
||||
export_args: None,
|
||||
disk_guid: None,
|
||||
restart: true,
|
||||
}));
|
||||
}
|
||||
@@ -117,7 +116,9 @@ async fn setup_or_init(
|
||||
server.serve_setup(ctx.clone());
|
||||
|
||||
let mut shutdown = ctx.shutdown.subscribe();
|
||||
shutdown.recv().await.expect("context dropped");
|
||||
if let Some(shutdown) = shutdown.recv().await.expect("context dropped") {
|
||||
return Ok(Err(shutdown));
|
||||
}
|
||||
|
||||
tokio::task::yield_now().await;
|
||||
if let Err(e) = Command::new("killall")
|
||||
@@ -183,7 +184,7 @@ async fn setup_or_init(
|
||||
let mut reboot_phase = handle.add_phase("Rebooting".into(), Some(1));
|
||||
reboot_phase.start();
|
||||
return Ok(Err(Shutdown {
|
||||
export_args: Some((disk_guid, Path::new(DATA_DIR).to_owned())),
|
||||
disk_guid: Some(disk_guid),
|
||||
restart: true,
|
||||
}));
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use cookie::{Cookie, Expiration, SameSite};
|
||||
use cookie_store::{CookieStore, RawCookie};
|
||||
use cookie_store::CookieStore;
|
||||
use imbl_value::InternedString;
|
||||
use josekit::jwk::Jwk;
|
||||
use once_cell::sync::OnceCell;
|
||||
@@ -13,7 +13,7 @@ use reqwest::Proxy;
|
||||
use reqwest_cookie_store::CookieStoreMutex;
|
||||
use rpc_toolkit::reqwest::{Client, Url};
|
||||
use rpc_toolkit::yajrc::RpcError;
|
||||
use rpc_toolkit::{call_remote_http, CallRemote, Context, Empty};
|
||||
use rpc_toolkit::{CallRemote, Context, Empty};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
|
||||
@@ -27,7 +27,6 @@ use crate::middleware::auth::AuthContext;
|
||||
use crate::prelude::*;
|
||||
use crate::rpc_continuations::Guid;
|
||||
use crate::tunnel::context::TunnelContext;
|
||||
use crate::tunnel::TUNNEL_DEFAULT_PORT;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CliContextSeed {
|
||||
|
||||
@@ -174,7 +174,7 @@ impl RpcContext {
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
webserver.try_upgrade(|a| net_ctrl.net_iface.upgrade_listener(a))?;
|
||||
webserver.try_upgrade(|a| net_ctrl.net_iface.watcher.upgrade_listener(a))?;
|
||||
let os_net_service = net_ctrl.os_bindings().await?;
|
||||
(net_ctrl, os_net_service)
|
||||
};
|
||||
|
||||
@@ -25,6 +25,7 @@ use crate::prelude::*;
|
||||
use crate::progress::FullProgressTracker;
|
||||
use crate::rpc_continuations::{Guid, RpcContinuation, RpcContinuations};
|
||||
use crate::setup::SetupProgress;
|
||||
use crate::shutdown::Shutdown;
|
||||
use crate::util::net::WebSocketExt;
|
||||
use crate::MAIN_DATA;
|
||||
|
||||
@@ -71,7 +72,8 @@ pub struct SetupContextSeed {
|
||||
pub progress: FullProgressTracker,
|
||||
pub task: OnceCell<NonDetachingJoinHandle<()>>,
|
||||
pub result: OnceCell<Result<(SetupResult, RpcContext), Error>>,
|
||||
pub shutdown: Sender<()>,
|
||||
pub disk_guid: OnceCell<Arc<String>>,
|
||||
pub shutdown: Sender<Option<Shutdown>>,
|
||||
pub rpc_continuations: RpcContinuations,
|
||||
}
|
||||
|
||||
@@ -97,6 +99,7 @@ impl SetupContext {
|
||||
progress: FullProgressTracker::new(),
|
||||
task: OnceCell::new(),
|
||||
result: OnceCell::new(),
|
||||
disk_guid: OnceCell::new(),
|
||||
shutdown,
|
||||
rpc_continuations: RpcContinuations::new(),
|
||||
})))
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use exver::{Version, VersionRange};
|
||||
use imbl::{OrdMap, OrdSet};
|
||||
use imbl_value::InternedString;
|
||||
use ipnet::IpNet;
|
||||
use isocountry::CountryCode;
|
||||
use itertools::Itertools;
|
||||
use models::PackageId;
|
||||
use lazy_static::lazy_static;
|
||||
use models::{GatewayId, PackageId};
|
||||
use openssl::hash::MessageDigest;
|
||||
use patch_db::{HasModel, Value};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -71,7 +73,8 @@ impl Public {
|
||||
net: NetInfo {
|
||||
assigned_port: None,
|
||||
assigned_ssl_port: Some(443),
|
||||
public: false,
|
||||
private_disabled: OrdSet::new(),
|
||||
public_enabled: OrdSet::new(),
|
||||
},
|
||||
},
|
||||
)]
|
||||
@@ -89,7 +92,7 @@ impl Public {
|
||||
enabled: true,
|
||||
..Default::default()
|
||||
},
|
||||
network_interfaces: BTreeMap::new(),
|
||||
network_interfaces: OrdMap::new(),
|
||||
acme: BTreeMap::new(),
|
||||
},
|
||||
status_info: ServerStatus {
|
||||
@@ -186,9 +189,9 @@ pub struct ServerInfo {
|
||||
pub struct NetworkInfo {
|
||||
pub wifi: WifiInfo,
|
||||
pub host: Host,
|
||||
#[ts(as = "BTreeMap::<String, NetworkInterfaceInfo>")]
|
||||
#[ts(as = "BTreeMap::<GatewayId, NetworkInterfaceInfo>")]
|
||||
#[serde(default)]
|
||||
pub network_interfaces: BTreeMap<InternedString, NetworkInterfaceInfo>,
|
||||
pub network_interfaces: OrdMap<GatewayId, NetworkInterfaceInfo>,
|
||||
#[serde(default)]
|
||||
pub acme: BTreeMap<AcmeProvider, AcmeSettings>,
|
||||
}
|
||||
@@ -199,9 +202,33 @@ pub struct NetworkInfo {
|
||||
#[ts(export)]
|
||||
pub struct NetworkInterfaceInfo {
|
||||
pub public: Option<bool>,
|
||||
pub secure: Option<bool>,
|
||||
pub ip_info: Option<IpInfo>,
|
||||
}
|
||||
impl NetworkInterfaceInfo {
|
||||
pub fn loopback() -> (&'static GatewayId, &'static Self) {
|
||||
lazy_static! {
|
||||
static ref LO: GatewayId = GatewayId::from("lo");
|
||||
static ref LOOPBACK: NetworkInterfaceInfo = NetworkInterfaceInfo {
|
||||
public: Some(false),
|
||||
secure: Some(true),
|
||||
ip_info: Some(IpInfo {
|
||||
name: "lo".into(),
|
||||
scope_id: 1,
|
||||
device_type: None,
|
||||
subnets: [
|
||||
IpNet::new(Ipv4Addr::LOCALHOST.into(), 8).unwrap(),
|
||||
IpNet::new(Ipv6Addr::LOCALHOST.into(), 128).unwrap(),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
wan_ip: None,
|
||||
ntp_servers: Default::default(),
|
||||
}),
|
||||
};
|
||||
}
|
||||
(&*LO, &*LOOPBACK)
|
||||
}
|
||||
pub fn public(&self) -> bool {
|
||||
self.public.unwrap_or_else(|| {
|
||||
!self.ip_info.as_ref().map_or(true, |ip_info| {
|
||||
@@ -233,6 +260,14 @@ impl NetworkInterfaceInfo {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn secure(&self) -> bool {
|
||||
self.secure.unwrap_or_else(|| {
|
||||
self.ip_info.as_ref().map_or(false, |ip_info| {
|
||||
ip_info.device_type == Some(NetworkInterfaceType::Wireguard)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize, TS, HasModel)]
|
||||
@@ -245,10 +280,10 @@ pub struct IpInfo {
|
||||
pub scope_id: u32,
|
||||
pub device_type: Option<NetworkInterfaceType>,
|
||||
#[ts(type = "string[]")]
|
||||
pub subnets: BTreeSet<IpNet>,
|
||||
pub subnets: OrdSet<IpNet>,
|
||||
pub wan_ip: Option<Ipv4Addr>,
|
||||
#[ts(type = "string[]")]
|
||||
pub ntp_servers: BTreeSet<InternedString>,
|
||||
pub ntp_servers: OrdSet<InternedString>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize, TS)]
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::marker::PhantomData;
|
||||
use std::str::FromStr;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use imbl::OrdMap;
|
||||
pub use imbl_value::Value;
|
||||
use patch_db::value::InternedString;
|
||||
pub use patch_db::{HasModel, MutateResult, PatchDb};
|
||||
@@ -199,6 +200,18 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<A, B> Map for OrdMap<A, B>
|
||||
where
|
||||
A: serde::Serialize + serde::de::DeserializeOwned + Clone + Ord + AsRef<str>,
|
||||
B: serde::Serialize + serde::de::DeserializeOwned + Clone,
|
||||
{
|
||||
type Key = A;
|
||||
type Value = B;
|
||||
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
|
||||
Ok(key.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Map> Model<T>
|
||||
where
|
||||
T::Value: Serialize,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use rpc_toolkit::yajrc::RpcError;
|
||||
@@ -12,7 +11,6 @@ use crate::init::SYSTEM_REBUILD_PATH;
|
||||
use crate::prelude::*;
|
||||
use crate::shutdown::Shutdown;
|
||||
use crate::util::io::delete_file;
|
||||
use crate::DATA_DIR;
|
||||
|
||||
pub fn diagnostic<C: Context>() -> ParentHandler<C> {
|
||||
ParentHandler::new()
|
||||
@@ -70,10 +68,7 @@ pub fn error(ctx: DiagnosticContext) -> Result<Arc<RpcError>, Error> {
|
||||
pub fn restart(ctx: DiagnosticContext) -> Result<(), Error> {
|
||||
ctx.shutdown
|
||||
.send(Shutdown {
|
||||
export_args: ctx
|
||||
.disk_guid
|
||||
.clone()
|
||||
.map(|guid| (guid, Path::new(DATA_DIR).to_owned())),
|
||||
disk_guid: ctx.disk_guid.clone(),
|
||||
restart: true,
|
||||
})
|
||||
.map_err(|_| eyre!("receiver dropped"))
|
||||
|
||||
@@ -35,8 +35,8 @@ impl Hostname {
|
||||
|
||||
pub fn generate_hostname() -> Hostname {
|
||||
let mut rng = rng();
|
||||
let adjective = &ADJECTIVES[rng.gen_range(0..ADJECTIVES.len())];
|
||||
let noun = &NOUNS[rng.gen_range(0..NOUNS.len())];
|
||||
let adjective = &ADJECTIVES[rng.random_range(0..ADJECTIVES.len())];
|
||||
let noun = &NOUNS[rng.random_range(0..NOUNS.len())];
|
||||
Hostname(InternedString::from_display(&lazy_format!(
|
||||
"{adjective}-{noun}"
|
||||
)))
|
||||
|
||||
@@ -216,7 +216,7 @@ pub async fn init(
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
webserver.try_upgrade(|a| net_ctrl.net_iface.upgrade_listener(a))?;
|
||||
webserver.try_upgrade(|a| net_ctrl.net_iface.watcher.upgrade_listener(a))?;
|
||||
let os_net_service = net_ctrl.os_bindings().await?;
|
||||
start_net.complete();
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::borrow::Borrow;
|
||||
use std::collections::BTreeMap;
|
||||
use std::future::Future;
|
||||
use std::net::Ipv4Addr;
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::time::Duration;
|
||||
@@ -18,7 +19,6 @@ use trust_dns_server::server::{Request, RequestHandler, ResponseHandler, Respons
|
||||
use trust_dns_server::ServerFuture;
|
||||
|
||||
use crate::net::forward::START9_BRIDGE_IFACE;
|
||||
use crate::util::sync::Watch;
|
||||
use crate::util::Invoke;
|
||||
use crate::{Error, ErrorKind, ResultExt};
|
||||
|
||||
@@ -140,7 +140,9 @@ impl RequestHandler for Resolver {
|
||||
|
||||
impl DnsController {
|
||||
#[instrument(skip_all)]
|
||||
pub async fn init(mut lxcbr_status: Watch<bool>) -> Result<Self, Error> {
|
||||
pub async fn init(
|
||||
bridge_activated: impl Future<Output = ()> + Send + Sync + 'static,
|
||||
) -> Result<Self, Error> {
|
||||
let services = Arc::new(RwLock::new(BTreeMap::new()));
|
||||
|
||||
let mut server = ServerFuture::new(Resolver {
|
||||
@@ -160,7 +162,7 @@ impl DnsController {
|
||||
.with_kind(ErrorKind::Network)?,
|
||||
);
|
||||
|
||||
lxcbr_status.wait_for(|a| *a).await;
|
||||
bridge_activated.await;
|
||||
|
||||
Command::new("resolvectl")
|
||||
.arg("dns")
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::net::{IpAddr, SocketAddr, SocketAddrV6};
|
||||
use std::sync::{Arc, Weak};
|
||||
|
||||
use futures::channel::oneshot;
|
||||
use helpers::NonDetachingJoinHandle;
|
||||
use id_pool::IdPool;
|
||||
use imbl_value::InternedString;
|
||||
use imbl::OrdMap;
|
||||
use models::GatewayId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::db::model::public::NetworkInterfaceInfo;
|
||||
use crate::net::network_interface::{DynInterfaceFilter, InterfaceFilter};
|
||||
use crate::net::utils::ipv6_is_link_local;
|
||||
use crate::prelude::*;
|
||||
use crate::util::sync::Watch;
|
||||
use crate::util::Invoke;
|
||||
@@ -39,106 +42,162 @@ impl AvailablePorts {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ForwardRequest {
|
||||
public: bool,
|
||||
external: u16,
|
||||
target: SocketAddr,
|
||||
filter: DynInterfaceFilter,
|
||||
rc: Weak<()>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct ForwardState {
|
||||
requested: BTreeMap<u16, ForwardRequest>,
|
||||
current: BTreeMap<u16, BTreeMap<InternedString, SocketAddr>>,
|
||||
struct ForwardEntry {
|
||||
external: u16,
|
||||
target: SocketAddr,
|
||||
prev_filter: DynInterfaceFilter,
|
||||
forwards: BTreeMap<SocketAddr, GatewayId>,
|
||||
rc: Weak<()>,
|
||||
}
|
||||
impl ForwardState {
|
||||
async fn sync(
|
||||
impl ForwardEntry {
|
||||
fn new(external: u16, target: SocketAddr, rc: Weak<()>) -> Self {
|
||||
Self {
|
||||
external,
|
||||
target,
|
||||
prev_filter: false.into_dyn(),
|
||||
forwards: BTreeMap::new(),
|
||||
rc,
|
||||
}
|
||||
}
|
||||
|
||||
fn take(&mut self) -> Self {
|
||||
Self {
|
||||
external: self.external,
|
||||
target: self.target,
|
||||
prev_filter: std::mem::replace(&mut self.prev_filter, false.into_dyn()),
|
||||
forwards: std::mem::take(&mut self.forwards),
|
||||
rc: self.rc.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn destroy(mut self) -> Result<(), Error> {
|
||||
while let Some((source, interface)) = self.forwards.pop_first() {
|
||||
unforward(interface.as_str(), source, self.target).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update(
|
||||
&mut self,
|
||||
interfaces: &BTreeMap<InternedString, (bool, Vec<Ipv4Addr>)>,
|
||||
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
|
||||
filter: Option<DynInterfaceFilter>,
|
||||
) -> Result<(), Error> {
|
||||
let private_interfaces = interfaces
|
||||
if self.rc.strong_count() == 0 {
|
||||
return self.take().destroy().await;
|
||||
}
|
||||
let filter_ref = filter.as_ref().unwrap_or(&self.prev_filter);
|
||||
let mut keep = BTreeSet::<SocketAddr>::new();
|
||||
for (iface, info) in ip_info
|
||||
.iter()
|
||||
.filter(|(_, (public, _))| !*public)
|
||||
.map(|(i, _)| i)
|
||||
.collect::<BTreeSet<_>>();
|
||||
let all_interfaces = interfaces.keys().collect::<BTreeSet<_>>();
|
||||
self.requested.retain(|_, req| req.rc.strong_count() > 0);
|
||||
for external in self
|
||||
.requested
|
||||
.keys()
|
||||
.chain(self.current.keys())
|
||||
.copied()
|
||||
.collect::<BTreeSet<_>>()
|
||||
.chain([NetworkInterfaceInfo::loopback()])
|
||||
.filter(|(id, info)| filter_ref.filter(*id, *info))
|
||||
{
|
||||
match (
|
||||
self.requested.get(&external),
|
||||
self.current.get_mut(&external),
|
||||
) {
|
||||
(Some(req), Some(cur)) => {
|
||||
let expected = if req.public {
|
||||
&all_interfaces
|
||||
} else {
|
||||
&private_interfaces
|
||||
if let Some(ip_info) = &info.ip_info {
|
||||
for ipnet in &ip_info.subnets {
|
||||
let addr = match ipnet.addr() {
|
||||
IpAddr::V6(ip6) => SocketAddrV6::new(
|
||||
ip6,
|
||||
self.external,
|
||||
0,
|
||||
if ipv6_is_link_local(ip6) {
|
||||
ip_info.scope_id
|
||||
} else {
|
||||
0
|
||||
},
|
||||
)
|
||||
.into(),
|
||||
ip => SocketAddr::new(ip, self.external),
|
||||
};
|
||||
let actual = cur.keys().collect::<BTreeSet<_>>();
|
||||
let mut to_rm = actual
|
||||
.difference(expected)
|
||||
.copied()
|
||||
.map(|i| (i.clone(), &interfaces[i].1))
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
let mut to_add = expected
|
||||
.difference(&actual)
|
||||
.copied()
|
||||
.map(|i| (i.clone(), &interfaces[i].1))
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
for interface in actual.intersection(expected).copied() {
|
||||
if cur[interface] != req.target {
|
||||
to_rm.insert(interface.clone(), &interfaces[interface].1);
|
||||
to_add.insert(interface.clone(), &interfaces[interface].1);
|
||||
}
|
||||
}
|
||||
for (interface, ips) in to_rm {
|
||||
for ip in ips {
|
||||
unforward(&*interface, (*ip, external).into(), cur[&interface]).await?;
|
||||
}
|
||||
cur.remove(&interface);
|
||||
}
|
||||
for (interface, ips) in to_add {
|
||||
cur.insert(interface.clone(), req.target);
|
||||
for ip in ips {
|
||||
forward(&*interface, (*ip, external).into(), cur[&interface]).await?;
|
||||
}
|
||||
keep.insert(addr);
|
||||
if !self.forwards.contains_key(&addr) {
|
||||
forward(iface.as_str(), addr, self.target).await?;
|
||||
self.forwards.insert(addr, iface.clone());
|
||||
}
|
||||
}
|
||||
(Some(req), None) => {
|
||||
let cur = self.current.entry(external).or_default();
|
||||
for interface in if req.public {
|
||||
&all_interfaces
|
||||
} else {
|
||||
&private_interfaces
|
||||
}
|
||||
.into_iter()
|
||||
.copied()
|
||||
{
|
||||
cur.insert(interface.clone(), req.target);
|
||||
for ip in &interfaces[interface].1 {
|
||||
forward(&**interface, (*ip, external).into(), req.target).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
(None, Some(cur)) => {
|
||||
let to_rm = cur.keys().cloned().collect::<BTreeSet<_>>();
|
||||
for interface in to_rm {
|
||||
for ip in &interfaces[&interface].1 {
|
||||
unforward(&*interface, (*ip, external).into(), cur[&interface]).await?;
|
||||
}
|
||||
cur.remove(&interface);
|
||||
}
|
||||
self.current.remove(&external);
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
let rm = self
|
||||
.forwards
|
||||
.keys()
|
||||
.copied()
|
||||
.filter(|a| !keep.contains(a))
|
||||
.collect::<Vec<_>>();
|
||||
for rm in rm {
|
||||
if let Some((source, interface)) = self.forwards.remove_entry(&rm) {
|
||||
unforward(interface.as_str(), source, self.target).await?;
|
||||
}
|
||||
}
|
||||
if let Some(filter) = filter {
|
||||
self.prev_filter = filter;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_request(
|
||||
&mut self,
|
||||
ForwardRequest {
|
||||
external,
|
||||
target,
|
||||
filter,
|
||||
rc,
|
||||
}: ForwardRequest,
|
||||
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
|
||||
) -> Result<(), Error> {
|
||||
if external != self.external || target != self.target {
|
||||
self.take().destroy().await?;
|
||||
*self = Self::new(external, target, rc);
|
||||
self.update(ip_info, Some(filter)).await?;
|
||||
} else {
|
||||
if self.prev_filter != filter {
|
||||
self.update(ip_info, Some(filter)).await?;
|
||||
}
|
||||
self.rc = rc;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
impl Drop for ForwardEntry {
|
||||
fn drop(&mut self) {
|
||||
if !self.forwards.is_empty() {
|
||||
let take = self.take();
|
||||
tokio::spawn(async move {
|
||||
take.destroy().await.log_err();
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct ForwardState {
|
||||
state: BTreeMap<u16, ForwardEntry>,
|
||||
}
|
||||
impl ForwardState {
|
||||
async fn handle_request(
|
||||
&mut self,
|
||||
request: ForwardRequest,
|
||||
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
|
||||
) -> Result<(), Error> {
|
||||
self.state
|
||||
.entry(request.external)
|
||||
.or_insert_with(|| ForwardEntry::new(request.external, request.target, Weak::new()))
|
||||
.update_request(request, ip_info)
|
||||
.await
|
||||
}
|
||||
async fn sync(
|
||||
&mut self,
|
||||
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
|
||||
) -> Result<(), Error> {
|
||||
for entry in self.state.values_mut() {
|
||||
entry.update(ip_info, None).await?;
|
||||
}
|
||||
self.state.retain(|_, fwd| !fwd.forwards.is_empty());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -150,87 +209,37 @@ fn err_has_exited<T>(_: T) -> Error {
|
||||
)
|
||||
}
|
||||
|
||||
pub struct LanPortForwardController {
|
||||
req: mpsc::UnboundedSender<(
|
||||
Option<(u16, ForwardRequest)>,
|
||||
oneshot::Sender<Result<(), Error>>,
|
||||
)>,
|
||||
pub struct PortForwardController {
|
||||
req: mpsc::UnboundedSender<(Option<ForwardRequest>, oneshot::Sender<Result<(), Error>>)>,
|
||||
_thread: NonDetachingJoinHandle<()>,
|
||||
}
|
||||
impl LanPortForwardController {
|
||||
pub fn new(mut ip_info: Watch<BTreeMap<InternedString, NetworkInterfaceInfo>>) -> Self {
|
||||
let (req_send, mut req_recv) = mpsc::unbounded_channel();
|
||||
impl PortForwardController {
|
||||
pub fn new(mut ip_info: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>) -> Self {
|
||||
let (req_send, mut req_recv) = mpsc::unbounded_channel::<(
|
||||
Option<ForwardRequest>,
|
||||
oneshot::Sender<Result<(), Error>>,
|
||||
)>();
|
||||
let thread = NonDetachingJoinHandle::from(tokio::spawn(async move {
|
||||
let mut state = ForwardState::default();
|
||||
let mut interfaces = ip_info.peek_and_mark_seen(|ip_info| {
|
||||
ip_info
|
||||
.iter()
|
||||
.map(|(iface, info)| {
|
||||
(
|
||||
iface.clone(),
|
||||
(
|
||||
info.public(),
|
||||
info.ip_info.as_ref().map_or(Vec::new(), |i| {
|
||||
i.subnets
|
||||
.iter()
|
||||
.filter_map(|s| {
|
||||
if let IpAddr::V4(ip) = s.addr() {
|
||||
Some(ip)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}),
|
||||
),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
let mut reply: Option<oneshot::Sender<Result<(), Error>>> = None;
|
||||
let mut interfaces = ip_info.read_and_mark_seen();
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = req_recv.recv() => {
|
||||
if let Some((msg, re)) = msg {
|
||||
if let Some((external, req)) = msg {
|
||||
state.requested.insert(external, req);
|
||||
if let Some(req) = msg {
|
||||
re.send(state.handle_request(req, &interfaces).await).ok();
|
||||
} else {
|
||||
re.send(state.sync(&interfaces).await).ok();
|
||||
}
|
||||
reply = Some(re);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ = ip_info.changed() => {
|
||||
interfaces = ip_info.peek(|ip_info| {
|
||||
ip_info
|
||||
.iter()
|
||||
.map(|(iface, info)| (iface.clone(), (
|
||||
info.public(),
|
||||
info.ip_info.as_ref().map_or(Vec::new(), |i| {
|
||||
i.subnets
|
||||
.iter()
|
||||
.filter_map(|s| {
|
||||
if let IpAddr::V4(ip) = s.addr() {
|
||||
Some(ip)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}),
|
||||
)))
|
||||
.collect()
|
||||
});
|
||||
interfaces = ip_info.read();
|
||||
state.sync(&interfaces).await.log_err();
|
||||
}
|
||||
}
|
||||
let res = state.sync(&interfaces).await;
|
||||
if let Err(e) = &res {
|
||||
tracing::error!("Error in PortForwardController: {e}");
|
||||
tracing::debug!("{e:?}");
|
||||
}
|
||||
if let Some(re) = reply.take() {
|
||||
let _ = re.send(res);
|
||||
}
|
||||
}
|
||||
}));
|
||||
Self {
|
||||
@@ -238,19 +247,22 @@ impl LanPortForwardController {
|
||||
_thread: thread,
|
||||
}
|
||||
}
|
||||
pub async fn add(&self, port: u16, public: bool, target: SocketAddr) -> Result<Arc<()>, Error> {
|
||||
pub async fn add(
|
||||
&self,
|
||||
external: u16,
|
||||
filter: impl InterfaceFilter,
|
||||
target: SocketAddr,
|
||||
) -> Result<Arc<()>, Error> {
|
||||
let rc = Arc::new(());
|
||||
let (send, recv) = oneshot::channel();
|
||||
self.req
|
||||
.send((
|
||||
Some((
|
||||
port,
|
||||
ForwardRequest {
|
||||
public,
|
||||
target,
|
||||
rc: Arc::downgrade(&rc),
|
||||
},
|
||||
)),
|
||||
Some(ForwardRequest {
|
||||
external,
|
||||
target,
|
||||
filter: filter.into_dyn(),
|
||||
rc: Arc::downgrade(&rc),
|
||||
}),
|
||||
send,
|
||||
))
|
||||
.map_err(err_has_exited)?;
|
||||
|
||||
@@ -131,7 +131,7 @@ pub fn address_api<C: Context, Kind: HostApiKind>(
|
||||
use prettytable::*;
|
||||
|
||||
if let Some(format) = params.format {
|
||||
display_serializable(format, res);
|
||||
display_serializable(format, res)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::str::FromStr;
|
||||
|
||||
use clap::builder::ValueParserFactory;
|
||||
use clap::Parser;
|
||||
use models::{FromStrParser, HostId};
|
||||
use imbl::OrdSet;
|
||||
use models::{FromStrParser, GatewayId, HostId};
|
||||
use rpc_toolkit::{from_fn_async, Context, Empty, HandlerArgs, HandlerExt, ParentHandler};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ts_rs::TS;
|
||||
|
||||
use crate::context::{CliContext, RpcContext};
|
||||
use crate::db::model::public::NetworkInterfaceInfo;
|
||||
use crate::net::forward::AvailablePorts;
|
||||
use crate::net::host::HostApiKind;
|
||||
use crate::net::network_interface::InterfaceFilter;
|
||||
use crate::net::vhost::AlpnInfo;
|
||||
use crate::prelude::*;
|
||||
use crate::util::serde::{display_serializable, HandlerExtSerde};
|
||||
@@ -50,11 +53,14 @@ pub struct BindInfo {
|
||||
pub net: NetInfo,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Serialize, TS, PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, TS, PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export)]
|
||||
pub struct NetInfo {
|
||||
pub public: bool,
|
||||
#[ts(as = "BTreeSet::<GatewayId>")]
|
||||
pub private_disabled: OrdSet<GatewayId>,
|
||||
#[ts(as = "BTreeSet::<GatewayId>")]
|
||||
pub public_enabled: OrdSet<GatewayId>,
|
||||
pub assigned_port: Option<u16>,
|
||||
pub assigned_ssl_port: Option<u16>,
|
||||
}
|
||||
@@ -65,16 +71,19 @@ impl BindInfo {
|
||||
if options.add_ssl.is_some() {
|
||||
assigned_ssl_port = Some(available_ports.alloc()?);
|
||||
}
|
||||
if let Some(secure) = options.secure {
|
||||
if !secure.ssl || !options.add_ssl.is_some() {
|
||||
assigned_port = Some(available_ports.alloc()?);
|
||||
}
|
||||
if options
|
||||
.secure
|
||||
.map_or(true, |s| !(s.ssl && options.add_ssl.is_some()))
|
||||
{
|
||||
assigned_port = Some(available_ports.alloc()?);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
enabled: true,
|
||||
options,
|
||||
net: NetInfo {
|
||||
public: false,
|
||||
private_disabled: OrdSet::new(),
|
||||
public_enabled: OrdSet::new(),
|
||||
assigned_port,
|
||||
assigned_ssl_port,
|
||||
},
|
||||
@@ -88,7 +97,7 @@ impl BindInfo {
|
||||
let Self { net: mut lan, .. } = self;
|
||||
if options
|
||||
.secure
|
||||
.map_or(false, |s| !(s.ssl && options.add_ssl.is_some()))
|
||||
.map_or(true, |s| !(s.ssl && options.add_ssl.is_some()))
|
||||
// doesn't make sense to have 2 listening ports, both with ssl
|
||||
{
|
||||
lan.assigned_port = if let Some(port) = lan.assigned_port.take() {
|
||||
@@ -122,6 +131,15 @@ impl BindInfo {
|
||||
self.enabled = false;
|
||||
}
|
||||
}
|
||||
impl InterfaceFilter for NetInfo {
|
||||
fn filter(&self, id: &GatewayId, info: &NetworkInterfaceInfo) -> bool {
|
||||
if info.public() {
|
||||
self.public_enabled.contains(id)
|
||||
} else {
|
||||
!self.private_disabled.contains(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, TS)]
|
||||
#[ts(export)]
|
||||
@@ -165,12 +183,11 @@ pub fn binding<C: Context, Kind: HostApiKind>(
|
||||
}
|
||||
|
||||
let mut table = Table::new();
|
||||
table.add_row(row![bc => "INTERNAL PORT", "ENABLED", "PUBLIC", "EXTERNAL PORT", "EXTERNAL SSL PORT"]);
|
||||
table.add_row(row![bc => "INTERNAL PORT", "ENABLED", "EXTERNAL PORT", "EXTERNAL SSL PORT"]);
|
||||
for (internal, info) in res {
|
||||
table.add_row(row![
|
||||
internal,
|
||||
info.enabled,
|
||||
info.net.public,
|
||||
if let Some(port) = info.net.assigned_port {
|
||||
port.to_string()
|
||||
} else {
|
||||
@@ -192,12 +209,12 @@ pub fn binding<C: Context, Kind: HostApiKind>(
|
||||
.with_call_remote::<CliContext>(),
|
||||
)
|
||||
.subcommand(
|
||||
"set-public",
|
||||
from_fn_async(set_public::<Kind>)
|
||||
"set-gateway-enabled",
|
||||
from_fn_async(set_gateway_enabled::<Kind>)
|
||||
.with_metadata("sync_db", Value::Bool(true))
|
||||
.with_inherited(Kind::inheritance)
|
||||
.no_display()
|
||||
.with_about("Add an binding to this host")
|
||||
.with_about("Set whether this gateway should be enabled for this binding")
|
||||
.with_call_remote::<CliContext>(),
|
||||
)
|
||||
}
|
||||
@@ -215,29 +232,50 @@ pub async fn list_bindings<Kind: HostApiKind>(
|
||||
#[derive(Deserialize, Serialize, Parser, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export)]
|
||||
pub struct BindingSetPublicParams {
|
||||
pub struct BindingGatewaySetEnabledParams {
|
||||
internal_port: u16,
|
||||
gateway: GatewayId,
|
||||
#[arg(long)]
|
||||
public: Option<bool>,
|
||||
enabled: Option<bool>,
|
||||
}
|
||||
|
||||
pub async fn set_public<Kind: HostApiKind>(
|
||||
pub async fn set_gateway_enabled<Kind: HostApiKind>(
|
||||
ctx: RpcContext,
|
||||
BindingSetPublicParams {
|
||||
BindingGatewaySetEnabledParams {
|
||||
internal_port,
|
||||
public,
|
||||
}: BindingSetPublicParams,
|
||||
gateway,
|
||||
enabled,
|
||||
}: BindingGatewaySetEnabledParams,
|
||||
inheritance: Kind::Inheritance,
|
||||
) -> Result<(), Error> {
|
||||
let enabled = enabled.unwrap_or(true);
|
||||
let gateway_public = ctx
|
||||
.net_controller
|
||||
.net_iface
|
||||
.watcher
|
||||
.ip_info()
|
||||
.get(&gateway)
|
||||
.or_not_found(&gateway)?
|
||||
.public();
|
||||
ctx.db
|
||||
.mutate(|db| {
|
||||
Kind::host_for(&inheritance, db)?
|
||||
.as_bindings_mut()
|
||||
.mutate(|b| {
|
||||
b.get_mut(&internal_port)
|
||||
.or_not_found(internal_port)?
|
||||
.net
|
||||
.public = public.unwrap_or(true);
|
||||
let net = &mut b.get_mut(&internal_port).or_not_found(internal_port)?.net;
|
||||
if gateway_public {
|
||||
if enabled {
|
||||
net.public_enabled.insert(gateway);
|
||||
} else {
|
||||
net.public_enabled.remove(&gateway);
|
||||
}
|
||||
} else {
|
||||
if enabled {
|
||||
net.private_disabled.remove(&gateway);
|
||||
} else {
|
||||
net.private_disabled.insert(gateway);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
|
||||
@@ -16,11 +16,14 @@ use crate::db::model::Database;
|
||||
use crate::error::ErrorCollection;
|
||||
use crate::hostname::Hostname;
|
||||
use crate::net::dns::DnsController;
|
||||
use crate::net::forward::LanPortForwardController;
|
||||
use crate::net::forward::{PortForwardController, START9_BRIDGE_IFACE};
|
||||
use crate::net::host::address::HostAddress;
|
||||
use crate::net::host::binding::{AddSslOptions, BindId, BindOptions};
|
||||
use crate::net::host::{host_for, Host, Hosts};
|
||||
use crate::net::network_interface::NetworkInterfaceController;
|
||||
use crate::net::network_interface::{
|
||||
AndFilter, DynInterfaceFilter, InterfaceFilter, LoopbackFilter, NetworkInterfaceController,
|
||||
SecureFilter,
|
||||
};
|
||||
use crate::net::service_interface::{HostnameInfo, IpHostname, OnionHostname};
|
||||
use crate::net::tor::TorController;
|
||||
use crate::net::utils::ipv6_is_local;
|
||||
@@ -36,7 +39,7 @@ pub struct NetController {
|
||||
pub(super) vhost: VHostController,
|
||||
pub(crate) net_iface: Arc<NetworkInterfaceController>,
|
||||
pub(super) dns: DnsController,
|
||||
pub(super) forward: LanPortForwardController,
|
||||
pub(super) forward: PortForwardController,
|
||||
pub(super) server_hostnames: Vec<Option<InternedString>>,
|
||||
pub(crate) callbacks: Arc<ServiceCallbacks>,
|
||||
}
|
||||
@@ -53,8 +56,13 @@ impl NetController {
|
||||
db: db.clone(),
|
||||
tor: TorController::new(tor_control, tor_socks),
|
||||
vhost: VHostController::new(db, net_iface.clone()),
|
||||
dns: DnsController::init(net_iface.lxcbr_status()).await?,
|
||||
forward: LanPortForwardController::new(net_iface.subscribe()),
|
||||
dns: DnsController::init(
|
||||
net_iface
|
||||
.watcher
|
||||
.wait_for_activated(START9_BRIDGE_IFACE.into()),
|
||||
)
|
||||
.await?,
|
||||
forward: PortForwardController::new(net_iface.watcher.subscribe()),
|
||||
net_iface,
|
||||
server_hostnames: vec![
|
||||
// LAN IP
|
||||
@@ -126,7 +134,7 @@ impl NetController {
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct HostBinds {
|
||||
forwards: BTreeMap<u16, (SocketAddr, bool, Arc<()>)>,
|
||||
forwards: BTreeMap<u16, (SocketAddr, DynInterfaceFilter, Arc<()>)>,
|
||||
vhosts: BTreeMap<(Option<InternedString>, u16), (TargetInfo, Arc<()>)>,
|
||||
tor: BTreeMap<OnionAddressV3, (OrdMap<u16, SocketAddr>, Vec<Arc<()>>)>,
|
||||
}
|
||||
@@ -217,7 +225,7 @@ impl NetServiceData {
|
||||
}
|
||||
|
||||
async fn update(&mut self, ctrl: &NetController, id: HostId, host: Host) -> Result<(), Error> {
|
||||
let mut forwards: BTreeMap<u16, (SocketAddr, bool)> = BTreeMap::new();
|
||||
let mut forwards: BTreeMap<u16, (SocketAddr, DynInterfaceFilter)> = BTreeMap::new();
|
||||
let mut vhosts: BTreeMap<(Option<InternedString>, u16), TargetInfo> = BTreeMap::new();
|
||||
let mut tor: BTreeMap<OnionAddressV3, (TorSecretKeyV3, OrdMap<u16, SocketAddr>)> =
|
||||
BTreeMap::new();
|
||||
@@ -228,7 +236,7 @@ impl NetServiceData {
|
||||
|
||||
// LAN
|
||||
let server_info = peek.as_public().as_server_info();
|
||||
let net_ifaces = ctrl.net_iface.ip_info();
|
||||
let net_ifaces = ctrl.net_iface.watcher.ip_info();
|
||||
let hostname = server_info.as_hostname().de()?;
|
||||
for (port, bind) in &host.bindings {
|
||||
if !bind.enabled {
|
||||
@@ -255,7 +263,7 @@ impl NetServiceData {
|
||||
vhosts.insert(
|
||||
(hostname, external),
|
||||
TargetInfo {
|
||||
public: bind.net.public,
|
||||
filter: bind.net.clone().into_dyn(),
|
||||
acme: None,
|
||||
addr,
|
||||
connect_ssl: connect_ssl.clone(),
|
||||
@@ -270,7 +278,7 @@ impl NetServiceData {
|
||||
vhosts.insert(
|
||||
(Some(hostname), external),
|
||||
TargetInfo {
|
||||
public: false,
|
||||
filter: LoopbackFilter.into_dyn(),
|
||||
acme: None,
|
||||
addr,
|
||||
connect_ssl: connect_ssl.clone(),
|
||||
@@ -286,11 +294,11 @@ impl NetServiceData {
|
||||
if hostnames.insert(address.clone()) {
|
||||
let address = Some(address.clone());
|
||||
if ssl.preferred_external_port == 443 {
|
||||
if public && bind.net.public {
|
||||
if public {
|
||||
vhosts.insert(
|
||||
(address.clone(), 5443),
|
||||
TargetInfo {
|
||||
public: false,
|
||||
filter: bind.net.clone().into_dyn(),
|
||||
acme: acme.clone(),
|
||||
addr,
|
||||
connect_ssl: connect_ssl.clone(),
|
||||
@@ -300,7 +308,7 @@ impl NetServiceData {
|
||||
vhosts.insert(
|
||||
(address.clone(), 443),
|
||||
TargetInfo {
|
||||
public: public && bind.net.public,
|
||||
filter: bind.net.clone().into_dyn(),
|
||||
acme,
|
||||
addr,
|
||||
connect_ssl: connect_ssl.clone(),
|
||||
@@ -310,7 +318,7 @@ impl NetServiceData {
|
||||
vhosts.insert(
|
||||
(address.clone(), external),
|
||||
TargetInfo {
|
||||
public: public && bind.net.public,
|
||||
filter: bind.net.clone().into_dyn(),
|
||||
acme,
|
||||
addr,
|
||||
connect_ssl: connect_ssl.clone(),
|
||||
@@ -322,28 +330,35 @@ impl NetServiceData {
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(security) = bind.options.secure {
|
||||
if bind.options.add_ssl.is_some() && security.ssl {
|
||||
// doesn't make sense to have 2 listening ports, both with ssl
|
||||
} else {
|
||||
let external = bind.net.assigned_port.or_not_found("assigned lan port")?;
|
||||
forwards.insert(external, ((self.ip, *port).into(), bind.net.public));
|
||||
}
|
||||
if bind
|
||||
.options
|
||||
.secure
|
||||
.map_or(true, |s| !(s.ssl && bind.options.add_ssl.is_some()))
|
||||
{
|
||||
let external = bind.net.assigned_port.or_not_found("assigned lan port")?;
|
||||
forwards.insert(
|
||||
external,
|
||||
(
|
||||
(self.ip, *port).into(),
|
||||
AndFilter(
|
||||
SecureFilter {
|
||||
secure: bind.options.secure.is_some(),
|
||||
},
|
||||
bind.net.clone(),
|
||||
)
|
||||
.into_dyn(),
|
||||
),
|
||||
);
|
||||
}
|
||||
let mut bind_hostname_info: Vec<HostnameInfo> =
|
||||
hostname_info.remove(port).unwrap_or_default();
|
||||
for (interface, public, ip_info) in
|
||||
net_ifaces.iter().filter_map(|(interface, info)| {
|
||||
if let Some(ip_info) = &info.ip_info {
|
||||
Some((interface, info.public(), ip_info))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
for (interface, info) in net_ifaces
|
||||
.iter()
|
||||
.filter(|(id, info)| bind.net.filter(id, info))
|
||||
{
|
||||
if !public {
|
||||
if !info.public() {
|
||||
bind_hostname_info.push(HostnameInfo::Ip {
|
||||
network_interface_id: interface.clone(),
|
||||
gateway_id: interface.clone(),
|
||||
public: false,
|
||||
hostname: IpHostname::Local {
|
||||
value: InternedString::from_display(&{
|
||||
@@ -357,47 +372,44 @@ impl NetServiceData {
|
||||
}
|
||||
for address in host.addresses() {
|
||||
if let HostAddress::Domain {
|
||||
address,
|
||||
public: domain_public,
|
||||
..
|
||||
address, public, ..
|
||||
} = address
|
||||
{
|
||||
if !public || (domain_public && bind.net.public) {
|
||||
if bind
|
||||
.options
|
||||
.add_ssl
|
||||
.as_ref()
|
||||
.map_or(false, |ssl| ssl.preferred_external_port == 443)
|
||||
{
|
||||
bind_hostname_info.push(HostnameInfo::Ip {
|
||||
network_interface_id: interface.clone(),
|
||||
public: public && domain_public && bind.net.public, // TODO: check if port forward is active
|
||||
hostname: IpHostname::Domain {
|
||||
domain: address.clone(),
|
||||
subdomain: None,
|
||||
port: None,
|
||||
ssl_port: Some(443),
|
||||
},
|
||||
});
|
||||
} else {
|
||||
bind_hostname_info.push(HostnameInfo::Ip {
|
||||
network_interface_id: interface.clone(),
|
||||
public,
|
||||
hostname: IpHostname::Domain {
|
||||
domain: address.clone(),
|
||||
subdomain: None,
|
||||
port: bind.net.assigned_port,
|
||||
ssl_port: bind.net.assigned_ssl_port,
|
||||
},
|
||||
});
|
||||
}
|
||||
if bind
|
||||
.options
|
||||
.add_ssl
|
||||
.as_ref()
|
||||
.map_or(false, |ssl| ssl.preferred_external_port == 443)
|
||||
{
|
||||
bind_hostname_info.push(HostnameInfo::Ip {
|
||||
gateway_id: interface.clone(),
|
||||
public, // TODO: check if port forward is active
|
||||
hostname: IpHostname::Domain {
|
||||
domain: address.clone(),
|
||||
subdomain: None,
|
||||
port: None,
|
||||
ssl_port: Some(443),
|
||||
},
|
||||
});
|
||||
} else {
|
||||
bind_hostname_info.push(HostnameInfo::Ip {
|
||||
gateway_id: interface.clone(),
|
||||
public,
|
||||
hostname: IpHostname::Domain {
|
||||
domain: address.clone(),
|
||||
subdomain: None,
|
||||
port: bind.net.assigned_port,
|
||||
ssl_port: bind.net.assigned_ssl_port,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
if !public || bind.net.public {
|
||||
if let Some(ip_info) = &info.ip_info {
|
||||
let public = info.public();
|
||||
if let Some(wan_ip) = ip_info.wan_ip.filter(|_| public) {
|
||||
bind_hostname_info.push(HostnameInfo::Ip {
|
||||
network_interface_id: interface.clone(),
|
||||
gateway_id: interface.clone(),
|
||||
public,
|
||||
hostname: IpHostname::Ipv4 {
|
||||
value: wan_ip,
|
||||
@@ -411,7 +423,7 @@ impl NetServiceData {
|
||||
IpNet::V4(net) => {
|
||||
if !public {
|
||||
bind_hostname_info.push(HostnameInfo::Ip {
|
||||
network_interface_id: interface.clone(),
|
||||
gateway_id: interface.clone(),
|
||||
public,
|
||||
hostname: IpHostname::Ipv4 {
|
||||
value: net.addr(),
|
||||
@@ -423,7 +435,7 @@ impl NetServiceData {
|
||||
}
|
||||
IpNet::V6(net) => {
|
||||
bind_hostname_info.push(HostnameInfo::Ip {
|
||||
network_interface_id: interface.clone(),
|
||||
gateway_id: interface.clone(),
|
||||
public: public && !ipv6_is_local(net.addr()),
|
||||
hostname: IpHostname::Ipv6 {
|
||||
value: net.addr(),
|
||||
@@ -509,8 +521,8 @@ impl NetServiceData {
|
||||
.collect::<BTreeSet<_>>();
|
||||
for external in all {
|
||||
let mut prev = binds.forwards.remove(&external);
|
||||
if let Some((internal, public)) = forwards.remove(&external) {
|
||||
prev = prev.filter(|(i, p, _)| i == &internal && *p == public);
|
||||
if let Some((internal, filter)) = forwards.remove(&external) {
|
||||
prev = prev.filter(|(i, f, _)| i == &internal && *f == filter);
|
||||
binds.forwards.insert(
|
||||
external,
|
||||
if let Some(prev) = prev {
|
||||
@@ -518,8 +530,8 @@ impl NetServiceData {
|
||||
} else {
|
||||
(
|
||||
internal,
|
||||
public,
|
||||
ctrl.forward.add(external, public, internal).await?,
|
||||
filter.clone(),
|
||||
ctrl.forward.add(external, filter, internal).await?,
|
||||
)
|
||||
},
|
||||
);
|
||||
@@ -662,7 +674,7 @@ impl NetService {
|
||||
}
|
||||
|
||||
fn new(data: NetServiceData) -> Result<Self, Error> {
|
||||
let mut ip_info = data.net_controller()?.net_iface.subscribe();
|
||||
let mut ip_info = data.net_controller()?.net_iface.watcher.subscribe();
|
||||
let data = Arc::new(Mutex::new(data));
|
||||
let thread_data = data.clone();
|
||||
let sync_task = tokio::spawn(async move {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,7 +2,7 @@ use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
|
||||
use imbl_value::InternedString;
|
||||
use lazy_format::lazy_format;
|
||||
use models::{HostId, ServiceInterfaceId};
|
||||
use models::{GatewayId, HostId, ServiceInterfaceId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ts_rs::TS;
|
||||
|
||||
@@ -14,7 +14,7 @@ use ts_rs::TS;
|
||||
pub enum HostnameInfo {
|
||||
Ip {
|
||||
#[ts(type = "string")]
|
||||
network_interface_id: InternedString,
|
||||
gateway_id: GatewayId,
|
||||
public: bool,
|
||||
hostname: IpHostname,
|
||||
},
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use clap::Parser;
|
||||
use imbl_value::InternedString;
|
||||
use models::GatewayId;
|
||||
use rpc_toolkit::{from_fn_async, Context, HandlerExt, ParentHandler};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::process::Command;
|
||||
@@ -44,7 +45,7 @@ pub async fn add_tunnel(
|
||||
config,
|
||||
public,
|
||||
}: AddTunnelParams,
|
||||
) -> Result<InternedString, Error> {
|
||||
) -> Result<GatewayId, Error> {
|
||||
let existing = ctx
|
||||
.db
|
||||
.peek()
|
||||
@@ -54,17 +55,17 @@ pub async fn add_tunnel(
|
||||
.into_network()
|
||||
.into_network_interfaces()
|
||||
.keys()?;
|
||||
let mut iface = InternedString::intern("wg0");
|
||||
let mut iface = GatewayId::from("wg0");
|
||||
for id in 1.. {
|
||||
if !existing.contains(&iface) {
|
||||
break;
|
||||
}
|
||||
iface = InternedString::from_display(&lazy_format!("wg{id}"));
|
||||
iface = InternedString::from_display(&lazy_format!("wg{id}")).into();
|
||||
}
|
||||
let tmpdir = TmpDir::new().await?;
|
||||
let conf = tmpdir.join(&*iface).with_extension("conf");
|
||||
let conf = tmpdir.join(&iface).with_extension("conf");
|
||||
write_file_atomic(&conf, &config).await?;
|
||||
let mut ifaces = ctx.net_controller.net_iface.subscribe();
|
||||
let mut ifaces = ctx.net_controller.net_iface.watcher.subscribe();
|
||||
Command::new("nmcli")
|
||||
.arg("connection")
|
||||
.arg("import")
|
||||
@@ -91,8 +92,7 @@ pub async fn add_tunnel(
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, Parser, TS)]
|
||||
#[ts(export)]
|
||||
pub struct RemoveTunnelParams {
|
||||
#[ts(type = "string")]
|
||||
id: InternedString,
|
||||
id: GatewayId,
|
||||
}
|
||||
pub async fn remove_tunnel(
|
||||
ctx: RpcContext,
|
||||
|
||||
@@ -10,8 +10,10 @@ use color_eyre::eyre::eyre;
|
||||
use futures::FutureExt;
|
||||
use helpers::NonDetachingJoinHandle;
|
||||
use http::Uri;
|
||||
use imbl::OrdMap;
|
||||
use imbl_value::InternedString;
|
||||
use models::ResultExt;
|
||||
use itertools::Itertools;
|
||||
use models::{GatewayId, ResultExt};
|
||||
use rpc_toolkit::{from_fn, Context, HandlerArgs, HandlerExt, ParentHandler};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
@@ -31,13 +33,16 @@ use tracing::instrument;
|
||||
use ts_rs::TS;
|
||||
|
||||
use crate::context::{CliContext, RpcContext};
|
||||
use crate::db::model::public::NetworkInterfaceInfo;
|
||||
use crate::db::model::Database;
|
||||
use crate::net::acme::{AcmeCertCache, AcmeProvider};
|
||||
use crate::net::network_interface::{
|
||||
Accepted, NetworkInterfaceController, NetworkInterfaceListener,
|
||||
Accepted, AnyFilter, DynInterfaceFilter, InterfaceFilter, NetworkInterfaceController,
|
||||
NetworkInterfaceListener,
|
||||
};
|
||||
use crate::net::static_server::server_error;
|
||||
use crate::prelude::*;
|
||||
use crate::util::collections::EqSet;
|
||||
use crate::util::io::BackTrackingIO;
|
||||
use crate::util::serde::{display_serializable, HandlerExtSerde, MaybeUtf8String};
|
||||
use crate::util::sync::SyncMutex;
|
||||
@@ -51,12 +56,13 @@ pub fn vhost_api<C: Context>() -> ParentHandler<C> {
|
||||
use prettytable::*;
|
||||
|
||||
if let Some(format) = params.format {
|
||||
display_serializable(format, res);
|
||||
display_serializable(format, res)?;
|
||||
return Ok::<_, Error>(());
|
||||
}
|
||||
|
||||
let mut table = Table::new();
|
||||
table.add_row(row![bc => "FROM", "TO", "PUBLIC", "ACME", "CONNECT SSL", "ACTIVE"]);
|
||||
table
|
||||
.add_row(row![bc => "FROM", "TO", "GATEWAYS", "ACME", "CONNECT SSL", "ACTIVE"]);
|
||||
|
||||
for (external, targets) in res {
|
||||
for (host, targets) in targets {
|
||||
@@ -68,7 +74,7 @@ pub fn vhost_api<C: Context>() -> ParentHandler<C> {
|
||||
external.0
|
||||
),
|
||||
target.addr,
|
||||
target.public,
|
||||
target.gateways.iter().join(", "),
|
||||
target.acme.as_ref().map(|a| a.0.as_str()).unwrap_or("NONE"),
|
||||
target.connect_ssl.is_ok(),
|
||||
idx == 0
|
||||
@@ -117,12 +123,7 @@ impl VHostController {
|
||||
&self,
|
||||
hostname: Option<InternedString>,
|
||||
external: u16,
|
||||
TargetInfo {
|
||||
public,
|
||||
acme,
|
||||
addr,
|
||||
connect_ssl,
|
||||
}: TargetInfo,
|
||||
target: TargetInfo,
|
||||
) -> Result<Arc<()>, Error> {
|
||||
self.servers.mutate(|writable| {
|
||||
let server = if let Some(server) = writable.remove(&external) {
|
||||
@@ -136,15 +137,7 @@ impl VHostController {
|
||||
self.acme_tls_alpn_cache.clone(),
|
||||
)?
|
||||
};
|
||||
let rc = server.add(
|
||||
hostname,
|
||||
TargetInfo {
|
||||
public,
|
||||
acme,
|
||||
addr,
|
||||
connect_ssl,
|
||||
},
|
||||
);
|
||||
let rc = server.add(hostname, target);
|
||||
writable.insert(external, server);
|
||||
Ok(rc?)
|
||||
})
|
||||
@@ -152,8 +145,9 @@ impl VHostController {
|
||||
|
||||
pub fn dump_table(
|
||||
&self,
|
||||
) -> BTreeMap<JsonKey<u16>, BTreeMap<JsonKey<Option<InternedString>>, BTreeSet<TargetInfo>>>
|
||||
) -> BTreeMap<JsonKey<u16>, BTreeMap<JsonKey<Option<InternedString>>, EqSet<ShowTargetInfo>>>
|
||||
{
|
||||
let ip_info = self.interfaces.watcher.ip_info();
|
||||
self.servers.peek(|s| {
|
||||
s.iter()
|
||||
.map(|(k, v)| {
|
||||
@@ -167,8 +161,7 @@ impl VHostController {
|
||||
JsonKey::new(k.clone()),
|
||||
v.iter()
|
||||
.filter(|(_, v)| v.strong_count() > 0)
|
||||
.map(|(k, _)| k)
|
||||
.cloned()
|
||||
.map(|(k, _)| ShowTargetInfo::new(k.clone(), &ip_info))
|
||||
.collect(),
|
||||
)
|
||||
})
|
||||
@@ -192,14 +185,45 @@ impl VHostController {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct TargetInfo {
|
||||
pub public: bool,
|
||||
pub filter: DynInterfaceFilter,
|
||||
pub acme: Option<AcmeProvider>,
|
||||
pub addr: SocketAddr,
|
||||
pub connect_ssl: Result<(), AlpnInfo>, // Ok: yes, connect using ssl, pass through alpn; Err: connect tcp, use provided strategy for alpn
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct ShowTargetInfo {
|
||||
pub gateways: BTreeSet<GatewayId>,
|
||||
pub acme: Option<AcmeProvider>,
|
||||
pub addr: SocketAddr,
|
||||
pub connect_ssl: Result<(), AlpnInfo>, // Ok: yes, connect using ssl, pass through alpn; Err: connect tcp, use provided strategy for alpn
|
||||
}
|
||||
impl ShowTargetInfo {
|
||||
pub fn new(
|
||||
TargetInfo {
|
||||
filter,
|
||||
acme,
|
||||
addr,
|
||||
connect_ssl,
|
||||
}: TargetInfo,
|
||||
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
|
||||
) -> Self {
|
||||
ShowTargetInfo {
|
||||
gateways: ip_info
|
||||
.iter()
|
||||
.filter(|(id, info)| filter.filter(*id, *info))
|
||||
.map(|(k, _)| k)
|
||||
.cloned()
|
||||
.collect(),
|
||||
acme,
|
||||
addr,
|
||||
connect_ssl,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export)]
|
||||
@@ -222,6 +246,21 @@ struct VHostServer {
|
||||
_thread: NonDetachingJoinHandle<()>,
|
||||
}
|
||||
|
||||
impl<'a> From<&'a BTreeMap<Option<InternedString>, BTreeMap<TargetInfo, Weak<()>>>> for AnyFilter {
|
||||
fn from(value: &'a BTreeMap<Option<InternedString>, BTreeMap<TargetInfo, Weak<()>>>) -> Self {
|
||||
Self(
|
||||
value
|
||||
.iter()
|
||||
.flat_map(|(_, v)| {
|
||||
v.iter()
|
||||
.filter(|(_, r)| r.strong_count() > 0)
|
||||
.map(|(t, _)| t.filter.clone())
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl VHostServer {
|
||||
async fn accept(
|
||||
listener: &mut NetworkInterfaceListener,
|
||||
@@ -233,35 +272,35 @@ impl VHostServer {
|
||||
let accepted;
|
||||
|
||||
loop {
|
||||
let any_public = mapping
|
||||
.borrow()
|
||||
.iter()
|
||||
.any(|(_, targets)| targets.iter().any(|(target, _)| target.public));
|
||||
let any_filter = AnyFilter::from(&*mapping.borrow());
|
||||
|
||||
let changed_public = mapping
|
||||
.wait_for(|m| {
|
||||
m.iter()
|
||||
.any(|(_, targets)| targets.iter().any(|(target, _)| target.public))
|
||||
!= any_public
|
||||
})
|
||||
let changed_filter = mapping
|
||||
.wait_for(|m| any_filter != AnyFilter::from(m))
|
||||
.boxed();
|
||||
|
||||
tokio::select! {
|
||||
a = listener.accept(any_public) => {
|
||||
a = listener.accept(&any_filter) => {
|
||||
accepted = a?;
|
||||
break;
|
||||
}
|
||||
_ = changed_public => {
|
||||
tracing::debug!("port {} {} public bindings", listener.port(), if any_public { "no longer has" } else { "now has" });
|
||||
_ = changed_filter => {
|
||||
tracing::debug!("port {} filter changed", listener.port());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let check = listener.check_filter();
|
||||
tokio::spawn(async move {
|
||||
let bind = accepted.bind;
|
||||
if let Err(e) =
|
||||
Self::handle_stream(accepted, mapping, db, acme_tls_alpn_cache, crypto_provider)
|
||||
.await
|
||||
if let Err(e) = Self::handle_stream(
|
||||
accepted,
|
||||
check,
|
||||
mapping,
|
||||
db,
|
||||
acme_tls_alpn_cache,
|
||||
crypto_provider,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!("Error in VHostController on {bind}: {e}");
|
||||
tracing::debug!("{e:?}")
|
||||
@@ -273,11 +312,11 @@ impl VHostServer {
|
||||
async fn handle_stream(
|
||||
Accepted {
|
||||
stream,
|
||||
is_public,
|
||||
wan_ip,
|
||||
bind,
|
||||
..
|
||||
}: Accepted,
|
||||
check_filter: impl FnOnce(SocketAddr, &DynInterfaceFilter) -> bool,
|
||||
mapping: watch::Receiver<Mapping>,
|
||||
db: TypedPatchDb<Database>,
|
||||
acme_tls_alpn_cache: AcmeTlsAlpnCache,
|
||||
@@ -431,10 +470,8 @@ impl VHostServer {
|
||||
.map(|(target, _)| target.clone())
|
||||
};
|
||||
if let Some(target) = target {
|
||||
if is_public && !target.public {
|
||||
log::warn!(
|
||||
"Rejecting connection from public interface to private bind: {bind} -> {target:?}"
|
||||
);
|
||||
if !check_filter(bind, &target.filter) {
|
||||
log::warn!("Connection from {bind} to {target:?} rejected by filter");
|
||||
return Ok(());
|
||||
}
|
||||
let peek = db.peek().await;
|
||||
@@ -660,7 +697,10 @@ impl VHostServer {
|
||||
crypto_provider: Arc<CryptoProvider>,
|
||||
acme_tls_alpn_cache: AcmeTlsAlpnCache,
|
||||
) -> Result<Self, Error> {
|
||||
let mut listener = iface_ctrl.bind(port).with_kind(crate::ErrorKind::Network)?;
|
||||
let mut listener = iface_ctrl
|
||||
.watcher
|
||||
.bind(port)
|
||||
.with_kind(crate::ErrorKind::Network)?;
|
||||
let (map_send, map_recv) = watch::channel(BTreeMap::new());
|
||||
Ok(Self {
|
||||
mapping: map_send,
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use std::future::Future;
|
||||
use std::net::SocketAddr;
|
||||
use std::ops::Deref;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::sync::Arc;
|
||||
use std::task::Poll;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -16,7 +15,7 @@ use tokio::sync::oneshot;
|
||||
|
||||
use crate::context::{DiagnosticContext, InitContext, InstallContext, RpcContext, SetupContext};
|
||||
use crate::net::network_interface::{
|
||||
NetworkInterfaceListener, SelfContainedNetworkInterfaceListener,
|
||||
lookup_info_by_addr, NetworkInterfaceListener, SelfContainedNetworkInterfaceListener,
|
||||
};
|
||||
use crate::net::static_server::{
|
||||
diagnostic_ui_router, init_ui_router, install_ui_router, main_ui_router, redirecter, refresher,
|
||||
@@ -50,10 +49,15 @@ impl Accept for Vec<TcpListener> {
|
||||
}
|
||||
impl Accept for NetworkInterfaceListener {
|
||||
fn poll_accept(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<Accepted, Error>> {
|
||||
NetworkInterfaceListener::poll_accept(self, cx, true).map(|res| {
|
||||
res.map(|a| Accepted {
|
||||
https_redirect: a.is_public,
|
||||
stream: a.stream,
|
||||
NetworkInterfaceListener::poll_accept(self, cx, &true).map(|res| {
|
||||
res.map(|a| {
|
||||
let public = self
|
||||
.ip_info
|
||||
.peek(|i| lookup_info_by_addr(i, a.bind).map_or(true, |(_, i)| i.public()));
|
||||
Accepted {
|
||||
https_redirect: public,
|
||||
stream: a.stream,
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -89,5 +89,6 @@ pub async fn get_service_port_forward(
|
||||
.de()?
|
||||
.get(&internal_port)
|
||||
.or_not_found(lazy_format!("binding for port {internal_port}"))?
|
||||
.net)
|
||||
.net
|
||||
.clone())
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ use crate::net::ssl::root_ca_start_time;
|
||||
use crate::prelude::*;
|
||||
use crate::progress::{FullProgress, PhaseProgressTrackerHandle, ProgressUnits};
|
||||
use crate::rpc_continuations::Guid;
|
||||
use crate::shutdown::Shutdown;
|
||||
use crate::system::sync_kiosk;
|
||||
use crate::util::crypto::EncryptedWire;
|
||||
use crate::util::io::{create_file, dir_copy, dir_size, Counter};
|
||||
@@ -67,6 +68,7 @@ pub fn setup<C: Context>() -> ParentHandler<C> {
|
||||
"logs",
|
||||
from_fn_async(crate::logs::cli_logs::<SetupContext, Empty>).no_display(),
|
||||
)
|
||||
.subcommand("restart", from_fn_async(restart).no_cli())
|
||||
}
|
||||
|
||||
pub fn disk<C: Context>() -> ParentHandler<C> {
|
||||
@@ -172,6 +174,7 @@ pub async fn attach(
|
||||
if disk_guid.ends_with("_UNENC") { None } else { Some(DEFAULT_PASSWORD) },
|
||||
)
|
||||
.await?;
|
||||
let _ = setup_ctx.disk_guid.set(disk_guid.clone());
|
||||
if tokio::fs::metadata(REPAIR_DISK_PATH).await.is_ok() {
|
||||
tokio::fs::remove_file(REPAIR_DISK_PATH)
|
||||
.await
|
||||
@@ -390,9 +393,19 @@ pub async fn complete(ctx: SetupContext) -> Result<SetupResult, Error> {
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
// #[command(rpc_only)]
|
||||
pub async fn exit(ctx: SetupContext) -> Result<(), Error> {
|
||||
ctx.shutdown.send(()).expect("failed to shutdown");
|
||||
ctx.shutdown.send(None).expect("failed to shutdown");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub async fn restart(ctx: SetupContext) -> Result<(), Error> {
|
||||
ctx.shutdown
|
||||
.send(Some(Shutdown {
|
||||
disk_guid: ctx.disk_guid.get().cloned(),
|
||||
restart: true,
|
||||
}))
|
||||
.expect("failed to shutdown");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -435,6 +448,7 @@ pub async fn execute_inner(
|
||||
);
|
||||
let _ = crate::disk::main::import(&*guid, DATA_DIR, RepairStrategy::Preen, encryption_password)
|
||||
.await?;
|
||||
let _ = ctx.disk_guid.set(guid.clone());
|
||||
disk_phase.complete();
|
||||
|
||||
let progress = SetupExecuteProgress {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::context::RpcContext;
|
||||
@@ -7,11 +6,11 @@ use crate::init::{STANDBY_MODE_PATH, SYSTEM_REBUILD_PATH};
|
||||
use crate::prelude::*;
|
||||
use crate::sound::SHUTDOWN;
|
||||
use crate::util::Invoke;
|
||||
use crate::{DATA_DIR, PLATFORM};
|
||||
use crate::PLATFORM;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Shutdown {
|
||||
pub export_args: Option<(Arc<String>, PathBuf)>,
|
||||
pub disk_guid: Option<Arc<String>>,
|
||||
pub restart: bool,
|
||||
}
|
||||
impl Shutdown {
|
||||
@@ -41,8 +40,8 @@ impl Shutdown {
|
||||
tracing::error!("Error Stopping Journald: {}", e);
|
||||
tracing::debug!("{:?}", e);
|
||||
}
|
||||
if let Some((guid, datadir)) = &self.export_args {
|
||||
if let Err(e) = export(guid, datadir).await {
|
||||
if let Some(guid) = &self.disk_guid {
|
||||
if let Err(e) = export(guid, crate::DATA_DIR).await {
|
||||
tracing::error!("Error Exporting Volume Group: {}", e);
|
||||
tracing::debug!("{:?}", e);
|
||||
}
|
||||
@@ -87,7 +86,7 @@ pub async fn shutdown(ctx: RpcContext) -> Result<(), Error> {
|
||||
.result?;
|
||||
ctx.shutdown
|
||||
.send(Some(Shutdown {
|
||||
export_args: Some((ctx.disk_guid.clone(), Path::new(DATA_DIR).to_owned())),
|
||||
disk_guid: Some(ctx.disk_guid.clone()),
|
||||
restart: false,
|
||||
}))
|
||||
.map_err(|_| eyre!("receiver dropped"))
|
||||
@@ -108,7 +107,7 @@ pub async fn restart(ctx: RpcContext) -> Result<(), Error> {
|
||||
.result?;
|
||||
ctx.shutdown
|
||||
.send(Some(Shutdown {
|
||||
export_args: Some((ctx.disk_guid.clone(), Path::new(DATA_DIR).to_owned())),
|
||||
disk_guid: Some(ctx.disk_guid.clone()),
|
||||
restart: true,
|
||||
}))
|
||||
.map_err(|_| eyre!("receiver dropped"))
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use std::collections::BTreeSet;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
|
||||
use std::ops::Deref;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use clap::Parser;
|
||||
use imbl::OrdMap;
|
||||
use imbl_value::InternedString;
|
||||
use patch_db::PatchDb;
|
||||
use rpc_toolkit::yajrc::RpcError;
|
||||
@@ -15,13 +16,15 @@ use tracing::instrument;
|
||||
|
||||
use crate::auth::{check_password, Sessions};
|
||||
use crate::context::config::ContextConfig;
|
||||
use crate::context::{CliContext, RpcContext};
|
||||
use crate::context::CliContext;
|
||||
use crate::middleware::auth::AuthContext;
|
||||
use crate::middleware::signature::SignatureAuthContext;
|
||||
use crate::net::forward::PortForwardController;
|
||||
use crate::net::network_interface::NetworkInterfaceWatcher;
|
||||
use crate::prelude::*;
|
||||
use crate::rpc_continuations::{OpenAuthedContinuations, RpcContinuations};
|
||||
use crate::tunnel::{TunnelDatabase, TUNNEL_DEFAULT_PORT};
|
||||
use crate::util::iter::TransposeResultIterExt;
|
||||
use crate::tunnel::db::TunnelDatabase;
|
||||
use crate::tunnel::TUNNEL_DEFAULT_PORT;
|
||||
use crate::util::sync::SyncMutex;
|
||||
|
||||
#[derive(Debug, Clone, Default, Deserialize, Serialize, Parser)]
|
||||
@@ -62,6 +65,8 @@ pub struct TunnelContextSeed {
|
||||
pub rpc_continuations: RpcContinuations,
|
||||
pub open_authed_continuations: OpenAuthedContinuations<Option<InternedString>>,
|
||||
pub ephemeral_sessions: SyncMutex<Sessions>,
|
||||
pub net_iface: NetworkInterfaceWatcher,
|
||||
pub forward: PortForwardController,
|
||||
pub shutdown: Sender<()>,
|
||||
}
|
||||
|
||||
@@ -89,6 +94,8 @@ impl TunnelContext {
|
||||
Ipv6Addr::UNSPECIFIED.into(),
|
||||
TUNNEL_DEFAULT_PORT,
|
||||
));
|
||||
let net_iface = NetworkInterfaceWatcher::new(async { OrdMap::new() }, []);
|
||||
let forward = PortForwardController::new(net_iface.subscribe());
|
||||
Ok(Self(Arc::new(TunnelContextSeed {
|
||||
listen,
|
||||
addrs: crate::net::utils::all_socket_addrs_for(listen.port())
|
||||
@@ -101,6 +108,8 @@ impl TunnelContext {
|
||||
rpc_continuations: RpcContinuations::new(),
|
||||
open_authed_continuations: OpenAuthedContinuations::new(),
|
||||
ephemeral_sessions: SyncMutex::new(Sessions::new()),
|
||||
net_iface,
|
||||
forward,
|
||||
shutdown,
|
||||
})))
|
||||
}
|
||||
@@ -213,14 +222,3 @@ impl CallRemote<TunnelContext> for CliContext {
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl CallRemote<TunnelContext, TunnelAddrParams> for RpcContext {
|
||||
async fn call_remote(
|
||||
&self,
|
||||
mut method: &str,
|
||||
params: Value,
|
||||
TunnelAddrParams { tunnel }: TunnelAddrParams,
|
||||
) -> Result<Value, RpcError> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::collections::{BTreeMap, HashSet};
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::path::PathBuf;
|
||||
|
||||
use clap::Parser;
|
||||
@@ -10,12 +12,31 @@ use serde::{Deserialize, Serialize};
|
||||
use tracing::instrument;
|
||||
use ts_rs::TS;
|
||||
|
||||
use crate::auth::Sessions;
|
||||
use crate::context::CliContext;
|
||||
use crate::prelude::*;
|
||||
use crate::sign::AnyVerifyingKey;
|
||||
use crate::tunnel::context::TunnelContext;
|
||||
use crate::tunnel::TunnelDatabase;
|
||||
use crate::util::serde::{apply_expr, HandlerExtSerde};
|
||||
|
||||
#[derive(Debug, Default, Deserialize, Serialize, HasModel)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[model = "Model<Self>"]
|
||||
pub struct TunnelDatabase {
|
||||
pub sessions: Sessions,
|
||||
pub password: String,
|
||||
pub auth_pubkeys: HashSet<AnyVerifyingKey>,
|
||||
pub clients: BTreeMap<Ipv4Addr, ClientInfo>,
|
||||
pub port_forwards: BTreeMap<SocketAddr, SocketAddr>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize, Serialize, HasModel)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[model = "Model<Self>"]
|
||||
pub struct ClientInfo {
|
||||
pub server: bool,
|
||||
}
|
||||
|
||||
pub fn db_api<C: Context>() -> ParentHandler<C> {
|
||||
ParentHandler::new()
|
||||
.subcommand(
|
||||
|
||||
0
core/startos/src/tunnel/init.rs
Normal file
0
core/startos/src/tunnel/init.rs
Normal file
@@ -1,4 +1,5 @@
|
||||
use std::collections::HashSet;
|
||||
use std::collections::{BTreeMap, HashSet};
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
|
||||
use axum::Router;
|
||||
use futures::future::ready;
|
||||
@@ -17,18 +18,10 @@ use crate::tunnel::context::TunnelContext;
|
||||
|
||||
pub mod context;
|
||||
pub mod db;
|
||||
pub mod init;
|
||||
|
||||
pub const TUNNEL_DEFAULT_PORT: u16 = 5960;
|
||||
|
||||
#[derive(Debug, Default, Deserialize, Serialize, HasModel)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[model = "Model<Self>"]
|
||||
pub struct TunnelDatabase {
|
||||
pub sessions: Sessions,
|
||||
pub password: String,
|
||||
pub auth_pubkeys: HashSet<AnyVerifyingKey>,
|
||||
}
|
||||
|
||||
pub fn tunnel_api<C: Context>() -> ParentHandler<C> {
|
||||
ParentHandler::new().subcommand(
|
||||
"db",
|
||||
|
||||
@@ -359,7 +359,7 @@ impl UploadHandle {
|
||||
});
|
||||
}
|
||||
}
|
||||
async fn process_body<E: Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>>(
|
||||
async fn process_body<E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>>(
|
||||
&mut self,
|
||||
mut body: impl Stream<Item = Result<Bytes, E>> + Unpin,
|
||||
) {
|
||||
|
||||
@@ -666,13 +666,27 @@ impl<K: Eq, V> IntoIterator for EqMap<K, V> {
|
||||
|
||||
impl<K: Eq, V> Extend<(K, V)> for EqMap<K, V> {
|
||||
fn extend<T: IntoIterator<Item = (K, V)>>(&mut self, iter: T) {
|
||||
self.0.extend(iter)
|
||||
let iter = iter.into_iter();
|
||||
if let (_, Some(len)) = iter.size_hint() {
|
||||
self.0.reserve(len)
|
||||
}
|
||||
for (k, v) in iter {
|
||||
self.insert(k, v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Eq, V> FromIterator<(K, V)> for EqMap<K, V> {
|
||||
fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
|
||||
Self(Vec::from_iter(iter))
|
||||
let mut res = Self(Vec::new());
|
||||
let iter = iter.into_iter();
|
||||
if let (_, Some(len)) = iter.size_hint() {
|
||||
res.0.reserve(len)
|
||||
}
|
||||
for (k, v) in iter {
|
||||
res.insert(k, v);
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
@@ -687,7 +701,7 @@ impl<K: Eq, V, const N: usize> From<[(K, V); N]> for EqMap<K, V> {
|
||||
/// assert_eq!(map1, map2);
|
||||
/// ```
|
||||
fn from(arr: [(K, V); N]) -> Self {
|
||||
EqMap(Vec::from(arr))
|
||||
Self::from_iter(arr)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
425
core/startos/src/util/collections/eq_set.rs
Normal file
425
core/startos/src/util/collections/eq_set.rs
Normal file
@@ -0,0 +1,425 @@
|
||||
use std::borrow::Borrow;
|
||||
use std::fmt;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Serialize)]
|
||||
pub struct EqSet<T: Eq>(Vec<T>);
|
||||
impl<T: Eq> Default for EqSet<T> {
|
||||
fn default() -> Self {
|
||||
Self(Default::default())
|
||||
}
|
||||
}
|
||||
impl<T: Eq> EqSet<T> {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.0.clear()
|
||||
}
|
||||
|
||||
/// Returns a reference to the element in the set, if any, that is equal to
|
||||
/// the value.
|
||||
///
|
||||
/// The value may be any borrowed form of the set's element type,
|
||||
/// but the ordering on the borrowed form *must* match the
|
||||
/// ordering on the element type.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use startos::util::collections::EqSet;
|
||||
///
|
||||
/// let set = EqSet::from([1, 2, 3]);
|
||||
/// assert_eq!(set.get(&2), Some(&2));
|
||||
/// assert_eq!(set.get(&4), None);
|
||||
/// ```
|
||||
pub fn get<Q: ?Sized>(&self, value: &Q) -> Option<&T>
|
||||
where
|
||||
T: Borrow<Q>,
|
||||
Q: Eq,
|
||||
{
|
||||
self.0.iter().find(|k| (*k).borrow() == value)
|
||||
}
|
||||
|
||||
/// Removes and returns an element in the set.
|
||||
/// There is no guarantee about which element this might be
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use startos::util::collections::EqSet;
|
||||
///
|
||||
/// let mut set = EqSet::new();
|
||||
/// set.insert("a");
|
||||
/// set.insert("b");
|
||||
/// while let Some(_val) = set.pop() { }
|
||||
/// assert!(set.is_empty());
|
||||
/// ```
|
||||
pub fn pop(&mut self) -> Option<T> {
|
||||
self.0.pop()
|
||||
}
|
||||
|
||||
/// Returns `true` if the set contains a value for the specified value.
|
||||
///
|
||||
/// The value may be any borrowed form of the set's value type, but the equality
|
||||
/// on the borrowed form *must* match the equality on the value type.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use startos::util::collections::EqSet;
|
||||
///
|
||||
/// let mut set = EqSet::new();
|
||||
/// set.insert("a");
|
||||
/// assert_eq!(set.contains("a"), true);
|
||||
/// assert_eq!(set.contains("b"), false);
|
||||
/// ```
|
||||
pub fn contains<Q: ?Sized>(&self, value: &Q) -> bool
|
||||
where
|
||||
T: Borrow<Q>,
|
||||
Q: Eq,
|
||||
{
|
||||
self.get(value).is_some()
|
||||
}
|
||||
|
||||
/// Inserts a value into the set.
|
||||
///
|
||||
/// If the set did not have this value present, `None` is returned.
|
||||
///
|
||||
/// If the set did have this value present, the value is updated, and the old
|
||||
/// value is returned.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use startos::util::collections::EqSet;
|
||||
///
|
||||
/// let mut set = EqSet::new();
|
||||
/// assert_eq!(set.insert("a"), None);
|
||||
/// assert_eq!(set.is_empty(), false);
|
||||
///
|
||||
/// set.insert("b");
|
||||
/// assert_eq!(set.insert("b"), Some("b"));
|
||||
/// assert!(set.contains("a"));
|
||||
/// ```
|
||||
pub fn insert(&mut self, value: T) -> Option<T> {
|
||||
if let Some(entry) = self.0.iter_mut().find(|a| *a == &value) {
|
||||
Some(std::mem::replace(entry, value))
|
||||
} else {
|
||||
self.0.push(value);
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Tries to insert a value into the set.
|
||||
///
|
||||
/// If the set already had this value present, nothing is updated.
|
||||
///
|
||||
/// Returns whether the value was inserted.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use startos::util::collections::EqSet;
|
||||
///
|
||||
/// let mut set = EqSet::new();
|
||||
/// assert!(set.try_insert("a"));
|
||||
/// assert!(!set.try_insert("a"));
|
||||
/// ```
|
||||
pub fn try_insert(&mut self, value: T) -> bool {
|
||||
if self.0.iter().find(|a| *a == &value).is_some() {
|
||||
false
|
||||
} else {
|
||||
self.0.push(value);
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes a value from the set, returning the value if it
|
||||
/// was previously in the set.
|
||||
///
|
||||
/// The value may be any borrowed form of the set's value type, but the equality
|
||||
/// on the borrowed form *must* match the equality on the value type.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use startos::util::collections::EqSet;
|
||||
///
|
||||
/// let mut set = EqSet::new();
|
||||
/// set.insert("a");
|
||||
/// assert_eq!(set.remove("a"), Some("a"));
|
||||
/// assert_eq!(set.remove("a"), None);
|
||||
/// ```
|
||||
pub fn remove<Q: ?Sized>(&mut self, value: &Q) -> Option<T>
|
||||
where
|
||||
T: Borrow<Q>,
|
||||
Q: Eq,
|
||||
{
|
||||
if let Some((idx, _)) = self
|
||||
.0
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, v)| (*v).borrow() == value)
|
||||
{
|
||||
Some(self.0.swap_remove(idx))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Retains only the elements specified by the predicate.
|
||||
///
|
||||
/// In other words, remove all pairs `(k, v)` for which `f(&k, &mut v)` returns `false`.
|
||||
/// The elements are visited in ascending value order.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use startos::util::collections::EqSet;
|
||||
///
|
||||
/// let mut set: EqSet<i32, i32> = (0..8).set(|x| (x, x*10)).collect();
|
||||
/// // Keep only the elements with even-numbered values.
|
||||
/// set.retain(|&k, _| k % 2 == 0);
|
||||
/// assert!(set.into_iter().eq(vec![(0, 0), (2, 20), (4, 40), (6, 60)]));
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn retain<F>(&mut self, f: F)
|
||||
where
|
||||
F: FnMut(&T) -> bool,
|
||||
{
|
||||
self.0.retain(f)
|
||||
}
|
||||
|
||||
/// Moves all elements from `other` into `self`, leaving `other` empty.
|
||||
///
|
||||
/// If a value from `other` is already present in `self`, the respective
|
||||
/// value from `self` will be overwritten with the respective value from `other`.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use startos::util::collections::EqSet;
|
||||
///
|
||||
/// let mut a = EqSet::new();
|
||||
/// a.insert("a");
|
||||
/// a.insert("b");
|
||||
/// a.insert("c"); // Note: "c" also present in b.
|
||||
///
|
||||
/// let mut b = EqSet::new();
|
||||
/// b.insert(3, "c"); // Note: "c" also present in a.
|
||||
/// b.insert(4, "d");
|
||||
/// b.insert(5, "e");
|
||||
///
|
||||
/// a.append(&mut b);
|
||||
///
|
||||
/// assert_eq!(a.len(), 5);
|
||||
/// assert_eq!(b.len(), 0);
|
||||
/// ```
|
||||
pub fn append(&mut self, other: &mut Self) {
|
||||
other.retain(|v| !self.contains(v));
|
||||
self.0.append(&mut other.0)
|
||||
}
|
||||
|
||||
// /// Creates an iterator that visits all elements (values) and
|
||||
// /// uses a closure to determine if an element should be removed. If the
|
||||
// /// closure returns `true`, the element is removed from the set and yielded.
|
||||
// /// If the closure returns `false`, or panics, the element remains in the set
|
||||
// /// and will not be yielded.
|
||||
// ///
|
||||
// /// The iterator also lets you mutate the value of each element in the
|
||||
// /// closure, regardless of whether you choose to keep or remove it.
|
||||
// ///
|
||||
// /// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating
|
||||
// /// or the iteration short-circuits, then the remaining elements will be retained.
|
||||
// /// Use [`retain`] with a negated predicate if you do not need the returned iterator.
|
||||
// ///
|
||||
// /// [`retain`]: EqSet::retain
|
||||
// ///
|
||||
// /// # Examples
|
||||
// ///
|
||||
// /// Splitting a set into even and odd values, reusing the original set:
|
||||
// ///
|
||||
// /// ```
|
||||
// /// use startos::util::collections::EqSet;
|
||||
// ///
|
||||
// /// let mut set: EqSet<i32, i32> = (0..8).set(|x| (x, x)).collect();
|
||||
// /// let evens: EqSet<_, _> = set.extract_if(|k, _v| k % 2 == 0).collect();
|
||||
// /// let odds = set;
|
||||
// /// assert_eq!(evens.values().copied().collect::<Vec<_>>(), [0, 2, 4, 6]);
|
||||
// /// assert_eq!(odds.values().copied().collect::<Vec<_>>(), [1, 3, 5, 7]);
|
||||
// /// ```
|
||||
// pub fn extract_if<F>(&mut self, pred: F) -> ExtractIf<'_, T, F>
|
||||
// where
|
||||
// K: Eq,
|
||||
// F: FnMut(&K, &mut V) -> bool,
|
||||
// {
|
||||
// let (inner, alloc) = self.extract_if_inner();
|
||||
// ExtractIf { pred, inner, alloc }
|
||||
// }
|
||||
|
||||
/// Gets an iterator over the entries of the set, in no particular order.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use startos::util::collections::EqSet;
|
||||
///
|
||||
/// let mut set = EqSet::new();
|
||||
/// set.insert("c");
|
||||
/// set.insert("b");
|
||||
/// set.insert("a");
|
||||
///
|
||||
/// for value in set.iter() {
|
||||
/// println!("{value}");
|
||||
/// }
|
||||
///
|
||||
/// let first_value = set.iter().next().unwrap();
|
||||
/// assert_eq!(*first_value, "c");
|
||||
/// ```
|
||||
pub fn iter(&self) -> std::slice::Iter<'_, T> {
|
||||
self.0.iter()
|
||||
}
|
||||
|
||||
/// Returns the number of elements in the set.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use startos::util::collections::EqSet;
|
||||
///
|
||||
/// let mut a = EqSet::new();
|
||||
/// assert_eq!(a.len(), 0);
|
||||
/// a.insert("a");
|
||||
/// assert_eq!(a.len(), 1);
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
|
||||
/// Returns `true` if the set contains no elements.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use startos::util::collections::EqSet;
|
||||
///
|
||||
/// let mut a = EqSet::new();
|
||||
/// assert!(a.is_empty());
|
||||
/// a.insert("a");
|
||||
/// assert!(!a.is_empty());
|
||||
/// ```
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: fmt::Debug + Eq> fmt::Debug for EqSet<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_set().entries(self.iter()).finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Eq> IntoIterator for EqSet<T> {
|
||||
type IntoIter = std::vec::IntoIter<T>;
|
||||
type Item = T;
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.0.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Eq> Extend<T> for EqSet<T> {
|
||||
fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
|
||||
let iter = iter.into_iter();
|
||||
if let (_, Some(len)) = iter.size_hint() {
|
||||
self.0.reserve(len)
|
||||
}
|
||||
for v in iter {
|
||||
self.insert(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Eq> FromIterator<T> for EqSet<T> {
|
||||
fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
|
||||
let mut res = Self(Vec::new());
|
||||
let iter = iter.into_iter();
|
||||
if let (_, Some(len)) = iter.size_hint() {
|
||||
res.0.reserve(len)
|
||||
}
|
||||
for v in iter {
|
||||
res.insert(v);
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Eq, const N: usize> From<[T; N]> for EqSet<T> {
|
||||
/// Converts a `[T; N]` into a `EqSet<T>`.
|
||||
///
|
||||
/// ```
|
||||
/// use startos::util::collections::EqSet;
|
||||
///
|
||||
/// let set1 = EqSet::from([(1, 2), (3, 4)]);
|
||||
/// let set2: EqSet<_, _> = [(1, 2), (3, 4)].into();
|
||||
/// assert_eq!(set1, set2);
|
||||
/// ```
|
||||
fn from(arr: [T; N]) -> Self {
|
||||
EqSet::from_iter(arr)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Eq> PartialEq for EqSet<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.len() == other.len() && self.iter().all(|v| other.get(v) == Some(v))
|
||||
}
|
||||
}
|
||||
impl<T: Eq> Eq for EqSet<T> {}
|
||||
|
||||
impl<'de, T> Deserialize<'de> for EqSet<T>
|
||||
where
|
||||
T: Deserialize<'de> + Eq,
|
||||
{
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct Visitor<T> {
|
||||
marker: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<'de, T> serde::de::Visitor<'de> for Visitor<T>
|
||||
where
|
||||
T: Deserialize<'de> + Eq,
|
||||
{
|
||||
type Value = EqSet<T>;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("a sequence")
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: serde::de::SeqAccess<'de>,
|
||||
{
|
||||
let mut values = EqSet(Vec::new());
|
||||
|
||||
while let Some(value) = seq.next_element()? {
|
||||
values.insert(value);
|
||||
}
|
||||
|
||||
Ok(values)
|
||||
}
|
||||
}
|
||||
|
||||
let visitor = Visitor {
|
||||
marker: PhantomData,
|
||||
};
|
||||
deserializer.deserialize_seq(visitor)
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,44 @@
|
||||
pub mod eq_map;
|
||||
pub mod eq_set;
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
pub use eq_map::EqMap;
|
||||
pub use eq_set::EqSet;
|
||||
use imbl::OrdMap;
|
||||
|
||||
pub struct OrdMapIterMut<'a, K: 'a, V: 'a> {
|
||||
map: *mut OrdMap<K, V>,
|
||||
prev: Option<&'a K>,
|
||||
_marker: PhantomData<&'a mut (K, V)>,
|
||||
}
|
||||
impl<'a, K, V> From<&'a mut OrdMap<K, V>> for OrdMapIterMut<'a, K, V> {
|
||||
fn from(value: &'a mut OrdMap<K, V>) -> Self {
|
||||
Self {
|
||||
map: value,
|
||||
prev: None,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<'a, K: Ord + Clone, V: Clone> Iterator for OrdMapIterMut<'a, K, V> {
|
||||
type Item = (&'a K, &'a mut V);
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
unsafe {
|
||||
let map: &'a mut OrdMap<K, V> = self.map.as_mut().unwrap();
|
||||
let res = if let Some(k) = self.prev.take() {
|
||||
map.get_next_mut(k)
|
||||
} else {
|
||||
let Some((k, _)) = map.get_min() else {
|
||||
return None;
|
||||
};
|
||||
let k = k.clone(); // hate that I have to do this but whatev
|
||||
map.get_key_value_mut(&k)
|
||||
};
|
||||
if let Some((k, _)) = &res {
|
||||
self.prev = Some(*k);
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ use base64::Engine;
|
||||
use clap::builder::ValueParserFactory;
|
||||
use clap::{ArgMatches, CommandFactory, FromArgMatches};
|
||||
use color_eyre::eyre::eyre;
|
||||
use imbl::OrdMap;
|
||||
use imbl_value::imbl::OrdMap;
|
||||
use models::FromStrParser;
|
||||
use openssl::pkey::{PKey, Private};
|
||||
use openssl::x509::X509;
|
||||
|
||||
@@ -140,6 +140,9 @@ impl<T: Clone> Watch<T> {
|
||||
pub fn read(&self) -> T {
|
||||
self.peek(|a| a.clone())
|
||||
}
|
||||
pub fn read_and_mark_seen(&mut self) -> T {
|
||||
self.peek_and_mark_seen(|a| a.clone())
|
||||
}
|
||||
}
|
||||
impl<T: Clone> futures::Stream for Watch<T> {
|
||||
type Item = T;
|
||||
|
||||
Reference in New Issue
Block a user