refactor complete

This commit is contained in:
Aiden McClelland
2025-10-29 11:17:49 -06:00
parent 124ed625d9
commit 5580ff6f01
40 changed files with 1171 additions and 907 deletions

54
core/Cargo.lock generated
View File

@@ -156,6 +156,15 @@ version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
[[package]]
name = "ansi-width"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "219e3ce6f2611d83b51ec2098a12702112c29e57203a6b0a0929b2cddb486608"
dependencies = [
"unicode-width 0.1.14",
]
[[package]]
name = "anstream"
version = "0.6.21"
@@ -1777,16 +1786,18 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "crossterm"
version = "0.27.0"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df"
checksum = "d8b9f2e4c67f833b660cdb0a3523065869fb35570177239812ed4c905aeff87b"
dependencies = [
"bitflags 2.10.0",
"crossterm_winapi",
"derive_more 2.0.1",
"document-features",
"futures-core",
"libc",
"mio 0.8.11",
"mio",
"parking_lot",
"rustix 1.1.2",
"signal-hook",
"signal-hook-mio",
"winapi",
@@ -4535,26 +4546,14 @@ dependencies = [
[[package]]
name = "mio"
version = "0.8.11"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c"
checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c"
dependencies = [
"libc",
"log",
"wasi 0.11.1+wasi-snapshot-preview1",
"windows-sys 0.48.0",
]
[[package]]
name = "mio"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69d83b0086dc8ecf3ce9ae2874b2d1290252e2a30720bea58a5c6639b0092873"
dependencies = [
"libc",
"log",
"wasi 0.11.1+wasi-snapshot-preview1",
"windows-sys 0.61.2",
"windows-sys 0.59.0",
]
[[package]]
@@ -4726,7 +4725,7 @@ dependencies = [
"kqueue",
"libc",
"log",
"mio 1.1.0",
"mio",
"notify-types",
"walkdir",
"windows-sys 0.60.2",
@@ -6440,18 +6439,17 @@ dependencies = [
[[package]]
name = "rustyline-async"
version = "0.4.2"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b6eb06391513b2184f0a5405c11a4a0a5302e8be442f4c5c35267187c2b37d5"
checksum = "6e07ddce8399c61495b405dc94d4f30d01fc1c5e1238f10b9c09940678bc81ab"
dependencies = [
"ansi-width",
"crossterm",
"futures-channel",
"futures-util",
"pin-project",
"thingbuf",
"thiserror 1.0.69",
"thiserror 2.0.17",
"unicode-segmentation",
"unicode-width 0.1.14",
]
[[package]]
@@ -6907,7 +6905,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd"
dependencies = [
"libc",
"mio 0.8.11",
"mio",
"signal-hook",
]
@@ -7338,7 +7336,7 @@ dependencies = [
"libc",
"log",
"mbrman",
"mio 1.1.0",
"mio",
"models",
"new_mime_guess",
"nix 0.30.1",
@@ -7793,7 +7791,7 @@ checksum = "ff360e02eab121e0bc37a2d3b4d4dc622e6eda3a8e5253d5435ecf5bd4c68408"
dependencies = [
"bytes",
"libc",
"mio 1.1.0",
"mio",
"parking_lot",
"pin-project-lite",
"signal-hook-registry",

View File

@@ -187,7 +187,6 @@ impl Display for ErrorKind {
}
}
#[derive(Debug)]
pub struct Error {
pub source: color_eyre::eyre::Error,
pub debug: Option<color_eyre::eyre::Error>,
@@ -198,7 +197,17 @@ pub struct Error {
impl Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}: {}", self.kind.as_str(), self.source)
write!(f, "{}: {:#}", self.kind.as_str(), self.source)
}
}
impl Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}: {:?}",
self.kind.as_str(),
self.debug.as_ref().unwrap_or(&self.source)
)
}
}
impl Error {

View File

@@ -222,7 +222,7 @@ reqwest_cookie_store = "0.8.0"
rpassword = "7.2.0"
rpc-toolkit = { git = "https://github.com/Start9Labs/rpc-toolkit.git", branch = "master" }
rust-argon2 = "2.0.0"
rustyline-async = "0.4.1"
rustyline-async = "0.4.7"
safelog = { version = "0.4.8", git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit", optional = true }
semver = { version = "1.0.20", features = ["serde"] }
serde = { version = "1.0", features = ["derive", "rc"] }

View File

@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::time::SystemTime;
use imbl_value::InternedString;
@@ -120,12 +121,20 @@ impl AccountInfo {
key_store.as_onion_mut().insert_key(tor_key)?;
}
let cert_store = key_store.as_local_certs_mut();
cert_store
.as_root_key_mut()
.ser(Pem::new_ref(&self.root_ca_key))?;
cert_store
.as_root_cert_mut()
.ser(Pem::new_ref(&self.root_ca_cert))?;
if cert_store.as_root_cert().de()?.0 != self.root_ca_cert {
cert_store
.as_root_key_mut()
.ser(Pem::new_ref(&self.root_ca_key))?;
cert_store
.as_root_cert_mut()
.ser(Pem::new_ref(&self.root_ca_cert))?;
let int_key = crate::net::ssl::generate_key()?;
let int_cert =
crate::net::ssl::make_int_cert((&self.root_ca_key, &self.root_ca_cert), &int_key)?;
cert_store.as_int_key_mut().ser(&Pem(int_key))?;
cert_store.as_int_cert_mut().ser(&Pem(int_cert))?;
cert_store.as_leaves_mut().ser(&BTreeMap::new())?;
}
Ok(())
}

View File

@@ -6,19 +6,22 @@ use rpc_toolkit::CliApp;
use tokio::signal::unix::signal;
use tracing::instrument;
use crate::context::CliContext;
use crate::context::config::ClientConfig;
use crate::context::CliContext;
use crate::net::web_server::{Acceptor, WebServer};
use crate::prelude::*;
use crate::registry::context::{RegistryConfig, RegistryContext};
use crate::registry::registry_router;
use crate::util::logger::LOGGER;
#[instrument(skip_all)]
async fn inner_main(config: &RegistryConfig) -> Result<(), Error> {
let server = async {
let ctx = RegistryContext::init(config).await?;
let mut server = WebServer::new(Acceptor::bind([ctx.listen]).await?);
server.serve_registry(ctx.clone());
let server = WebServer::new(
Acceptor::bind([ctx.listen]).await?,
registry_router(ctx.clone()),
);
let mut shutdown_recv = ctx.shutdown.subscribe();

View File

@@ -11,7 +11,8 @@ use crate::disk::main::DEFAULT_PASSWORD;
use crate::disk::REPAIR_DISK_PATH;
use crate::firmware::{check_for_firmware_update, update_firmware};
use crate::init::{InitPhases, STANDBY_MODE_PATH};
use crate::net::web_server::{UpgradableListener, WebServer};
use crate::net::gateway::UpgradableListener;
use crate::net::web_server::WebServer;
use crate::prelude::*;
use crate::progress::FullProgressTracker;
use crate::shutdown::Shutdown;

View File

@@ -12,8 +12,9 @@ use tracing::instrument;
use crate::context::config::ServerConfig;
use crate::context::rpc::InitRpcContextPhases;
use crate::context::{DiagnosticContext, InitContext, RpcContext};
use crate::net::gateway::SelfContainedNetworkInterfaceListener;
use crate::net::web_server::{Acceptor, UpgradableListener, WebServer};
use crate::net::gateway::{BindTcp, SelfContainedNetworkInterfaceListener, UpgradableListener};
use crate::net::static_server::refresher;
use crate::net::web_server::{Acceptor, WebServer};
use crate::shutdown::Shutdown;
use crate::system::launch_metrics_task;
use crate::util::io::append_file;
@@ -147,9 +148,10 @@ pub fn main(args: impl IntoIterator<Item = OsString>) {
.build()
.expect("failed to initialize runtime");
let res = rt.block_on(async {
let mut server = WebServer::new(Acceptor::bind_upgradable(
SelfContainedNetworkInterfaceListener::bind(80),
));
let mut server = WebServer::new(
Acceptor::bind_upgradable(SelfContainedNetworkInterfaceListener::bind(BindTcp, 80)),
refresher(),
);
match inner_main(&mut server, &config).await {
Ok(a) => {
server.shutdown().await;

View File

@@ -1,4 +1,6 @@
use std::collections::BTreeMap;
use std::ffi::OsString;
use std::net::SocketAddr;
use std::time::Duration;
use clap::Parser;
@@ -7,44 +9,77 @@ use helpers::NonDetachingJoinHandle;
use rpc_toolkit::CliApp;
use tokio::signal::unix::signal;
use tracing::instrument;
use visit_rs::Visit;
use crate::context::config::ClientConfig;
use crate::context::CliContext;
use crate::net::web_server::{Acceptor, WebServer};
use crate::net::gateway::{Bind, BindTcp};
use crate::net::tls::{ChainedHandler, TlsListener};
use crate::net::web_server::{Accept, Acceptor, MetadataVisitor, WebServer};
use crate::prelude::*;
use crate::tunnel::context::{TunnelConfig, TunnelContext};
use crate::tunnel::tunnel_router;
use crate::util::logger::LOGGER;
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum WebserverListener {
Http,
Https(SocketAddr),
}
impl<V: MetadataVisitor> Visit<V> for WebserverListener {
fn visit(&self, visitor: &mut V) -> <V as visit_rs::Visitor>::Result {
visitor.visit(self)
}
}
#[instrument(skip_all)]
async fn inner_main(config: &TunnelConfig) -> Result<(), Error> {
let server = async {
let ctx = TunnelContext::init(config).await?;
let listen = ctx.listen;
let mut server = WebServer::new(Acceptor::bind([listen]).await?);
let server = WebServer::new(
Acceptor::bind_map_dyn([(WebserverListener::Http, listen)]).await?,
tunnel_router(ctx.clone()),
);
let acceptor_setter = server.acceptor_setter();
let https_db = ctx.db.clone();
let https_thread: NonDetachingJoinHandle<()> = tokio::spawn(async move {
let mut sub = setter_db.subscribe("/webserver".parse().unwrap()).await;
let mut sub = https_db.subscribe("/webserver".parse().unwrap()).await;
while sub.recv().await.is_some() {
while let Err(e) = async {
let external = setter_db.peek().await.into_webserver().de()?;
let mut bind_err = None;
setter.send_modify(|a| {
a.retain(|a, _| *a == listen || Some(*a) == external);
if let Some(external) = external {
if !a.contains_key(&external) {
match mio::net::TcpListener::bind(external) {
if let Some(addr) = https_db.peek().await.as_webserver().de()? {
acceptor_setter.send_if_modified(|a| {
let key = WebserverListener::Https(addr);
if !a.contains_key(&key) {
match (|| {
Ok::<_, Error>(TlsListener::new(
BindTcp.bind(addr)?,
BasicCertHandler(https_db.clone()),
))
})() {
Ok(l) => {
a.insert(external, TcpListener::from_std(l.into()));
}
Err(e) => bind_err = Some(e),
}
}
}
});
a.retain(|k, _| *k == WebserverListener::Http);
a.insert(key, l.into_dyn());
if let Some(e) = bind_err {
return Err(e);
true
}
Err(e) => {
tracing::error!("error adding ssl listener: {e}");
tracing::debug!("{e:?}");
false
}
}
} else {
false
}
});
} else {
acceptor_setter.send_if_modified(|a| {
let before = a.len();
a.retain(|k, _| *k == WebserverListener::Http);
a.len() != before
});
}
Ok::<_, Error>(())
@@ -58,7 +93,6 @@ async fn inner_main(config: &TunnelConfig) -> Result<(), Error> {
}
})
.into();
server.serve_tunnel(ctx.clone());
let mut shutdown_recv = ctx.shutdown.subscribe();
@@ -97,7 +131,7 @@ async fn inner_main(config: &TunnelConfig) -> Result<(), Error> {
.with_kind(crate::ErrorKind::Unknown)?;
sig_handler.wait_for_abort().await;
setter_thread.wait_for_abort().await;
https_thread.wait_for_abort().await;
Ok::<_, Error>(server)
}

View File

@@ -172,7 +172,7 @@ impl CliContext {
return Ok(secret.into())
}
Err(Error::new(
eyre!("Developer Key does not exist! Please run `start-cli init` before running this command."),
eyre!("Developer Key does not exist! Please run `start-cli init-key` before running this command."),
crate::ErrorKind::Uninitialized
))
})

View File

@@ -31,10 +31,11 @@ use crate::disk::OsPartitionInfo;
use crate::init::{check_time_is_synchronized, InitResult};
use crate::install::PKG_ARCHIVE_DIR;
use crate::lxc::LxcManager;
use crate::net::gateway::UpgradableListener;
use crate::net::net_controller::{NetController, NetService};
use crate::net::socks::DEFAULT_SOCKS_LISTEN;
use crate::net::utils::{find_eth_iface, find_wifi_iface};
use crate::net::web_server::{UpgradableListener, WebServerAcceptorSetter};
use crate::net::web_server::WebServerAcceptorSetter;
use crate::net::wifi::WpaCli;
use crate::prelude::*;
use crate::progress::{FullProgressTracker, PhaseProgressTrackerHandle};
@@ -46,7 +47,7 @@ use crate::shutdown::Shutdown;
use crate::util::io::delete_file;
use crate::util::lshw::LshwDevice;
use crate::util::sync::{SyncMutex, SyncRwLock, Watch};
use crate::{DATA_DIR, HOST_IP};
use crate::DATA_DIR;
pub struct RpcContextSeed {
is_closed: AtomicBool,

View File

@@ -20,7 +20,8 @@ use crate::context::config::ServerConfig;
use crate::context::RpcContext;
use crate::disk::OsPartitionInfo;
use crate::hostname::Hostname;
use crate::net::web_server::{UpgradableListener, WebServer, WebServerAcceptorSetter};
use crate::net::gateway::UpgradableListener;
use crate::net::web_server::{WebServer, WebServerAcceptorSetter};
use crate::prelude::*;
use crate::progress::FullProgressTracker;
use crate::rpc_continuations::{Guid, RpcContinuation, RpcContinuations};

View File

@@ -31,13 +31,23 @@ lazy_static::lazy_static! {
}
pub trait DbAccess<T>: Sized {
type Key<'a>;
fn access<'a>(db: &'a Model<Self>, key: Self::Key<'_>) -> &'a Model<T>;
fn access<'a>(db: &'a Model<Self>) -> &'a Model<T>;
}
pub trait DbAccessMut<T>: Sized {
pub trait DbAccessMut<T>: DbAccess<T> {
fn access_mut<'a>(db: &'a mut Model<Self>) -> &'a mut Model<T>;
}
pub trait DbAccessByKey<T>: Sized {
type Key<'a>;
fn access_mut<'a>(db: &'a mut Model<Self>, key: Self::Key<'_>) -> &'a mut Model<T>;
fn access_by_key<'a>(db: &'a Model<Self>, key: Self::Key<'_>) -> Option<&'a Model<T>>;
}
pub trait DbAccessMutByKey<T>: DbAccessByKey<T> {
fn access_mut_by_key<'a>(
db: &'a mut Model<Self>,
key: Self::Key<'_>,
) -> Option<&'a mut Model<T>>;
}
pub fn db<C: Context>() -> ParentHandler<C> {

View File

@@ -17,6 +17,8 @@ use ts_rs::TS;
use crate::account::AccountInfo;
use crate::db::model::package::AllPackageData;
use crate::db::model::Database;
use crate::db::DbAccessByKey;
use crate::net::acme::AcmeProvider;
use crate::net::host::binding::{AddSslOptions, BindInfo, BindOptions, NetInfo};
use crate::net::host::Host;
@@ -295,6 +297,19 @@ pub enum NetworkInterfaceType {
pub struct AcmeSettings {
pub contact: Vec<String>,
}
impl DbAccessByKey<AcmeSettings> for Database {
type Key<'a> = &'a AcmeProvider;
fn access_by_key<'a>(
db: &'a Model<Self>,
key: Self::Key<'_>,
) -> Option<&'a Model<AcmeSettings>> {
db.as_public()
.as_server_info()
.as_network()
.as_acme()
.as_idx(key)
}
}
#[derive(Debug, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]

View File

@@ -1,6 +1,7 @@
use std::collections::{BTreeMap, BTreeSet};
use std::marker::PhantomData;
use std::str::FromStr;
use std::sync::Arc;
use chrono::{DateTime, Utc};
use imbl::OrdMap;
@@ -167,6 +168,21 @@ impl<T> Model<Option<T>> {
}
}
impl<T> Model<Arc<T>> {
pub fn deref(self) -> Model<T> {
use patch_db::ModelExt;
self.transmute(|a| a)
}
pub fn as_deref(&self) -> &Model<T> {
use patch_db::ModelExt;
self.transmute_ref(|a| a)
}
pub fn as_deref_mut(&mut self) -> &mut Model<T> {
use patch_db::ModelExt;
self.transmute_mut(|a| a)
}
}
pub trait Map: DeserializeOwned + Serialize {
type Key;
type Value;
@@ -196,7 +212,7 @@ where
type Key = JsonKey<A>;
type Value = B;
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
serde_json::to_string(key).with_kind(ErrorKind::Serialization)
serde_json::to_string(&key.0).with_kind(ErrorKind::Serialization)
}
}

View File

@@ -226,7 +226,7 @@ pub fn main_api<C: Context>() -> ParentHandler<C> {
util::rpc::util::<C>().with_about("Command for calculating the blake3 hash of a file"),
)
.subcommand(
"init",
"init-key",
from_fn_async(developer::init)
.no_display()
.with_about("Create developer key if it doesn't exist"),
@@ -241,6 +241,7 @@ pub fn main_api<C: Context>() -> ParentHandler<C> {
diagnostic::diagnostic::<C>()
.with_about("Commands to display logs, restart the server, etc"),
)
.subcommand("init", init::init_api::<C>())
.subcommand("setup", setup::setup::<C>())
.subcommand(
"install",

View File

@@ -6,7 +6,7 @@ use std::sync::Arc;
use async_acme::acme::{Identifier, ACME_TLS_ALPN_NAME};
use clap::builder::ValueParserFactory;
use clap::Parser;
use futures::StreamExt;
use futures::{FutureExt, StreamExt};
use imbl_value::InternedString;
use itertools::Itertools;
use models::{ErrorData, FromStrParser};
@@ -16,6 +16,7 @@ use rpc_toolkit::{from_fn_async, Context, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use tokio_rustls::rustls::crypto::CryptoProvider;
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio_rustls::rustls::server::ClientHello;
use tokio_rustls::rustls::sign::CertifiedKey;
use tokio_rustls::rustls::ServerConfig;
use ts_rs::TS;
@@ -24,7 +25,7 @@ use url::Url;
use crate::context::{CliContext, RpcContext};
use crate::db::model::public::AcmeSettings;
use crate::db::model::Database;
use crate::db::{DbAccess, DbAccessMut};
use crate::db::{DbAccess, DbAccessByKey, DbAccessMut};
use crate::net::tls::{SingleCertResolver, TlsHandler};
use crate::net::web_server::Accept;
use crate::prelude::*;
@@ -34,28 +35,28 @@ use crate::util::sync::{SyncMutex, Watch};
pub type AcmeTlsAlpnCache =
Arc<SyncMutex<BTreeMap<InternedString, Watch<Option<Arc<CertifiedKey>>>>>>;
pub struct AcmeTlsHandler<'a, M: HasModel, S: 'a> {
pub db: &'a TypedPatchDb<M>,
pub acme_cache: &'a AcmeTlsAlpnCache,
pub crypto_provider: &'a Arc<CryptoProvider>,
pub struct AcmeTlsHandler<M: HasModel, S> {
pub db: TypedPatchDb<M>,
pub acme_cache: AcmeTlsAlpnCache,
pub crypto_provider: Arc<CryptoProvider>,
pub get_provider: S,
pub in_progress: Watch<BTreeSet<BTreeSet<InternedString>>>,
}
impl<'b, M, S> AcmeTlsHandler<'b, M, S>
impl<M, S> AcmeTlsHandler<M, S>
where
for<'a> M: DbAccess<AcmeCertStore, Key<'a> = ()>
+ DbAccess<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ DbAccessMut<AcmeCertStore, Key<'a> = ()>
for<'a> M: DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ DbAccessMut<AcmeCertStore>
+ HasModel<Model = Model<M>>
+ Send
+ Sync,
S: GetAcmeProvider<'b> + Clone + 'b,
S: GetAcmeProvider + Clone,
{
pub async fn get_cert(&self, san_info: &BTreeSet<InternedString>) -> Option<CertifiedKey> {
let provider = self.get_provider.clone().get_provider(san_info).await?;
let provider = self.get_provider.get_provider(san_info).await?;
let provider = provider.as_ref();
loop {
let peek = self.db.peek().await;
let store = <M as DbAccess<AcmeCertStore>>::access(&peek, ());
let store = <M as DbAccess<AcmeCertStore>>::access(&peek);
if let Some(cert) = store
.as_certs()
.as_idx(&provider.0)
@@ -93,7 +94,7 @@ where
continue;
}
let contact = <M as DbAccess<AcmeSettings>>::access(&peek, provider)
let contact = <M as DbAccessByKey<AcmeSettings>>::access_by_key(&peek, &provider)?
.as_contact()
.de()
.log_err()?;
@@ -141,37 +142,29 @@ where
}
}
pub trait GetAcmeProvider<'a> {
fn get_provider<'b>(
self,
san_info: &'b BTreeSet<InternedString>,
) -> impl Future<Output = Option<&'a AcmeProvider>> + Send + 'b
where
Self: 'b;
pub trait GetAcmeProvider {
fn get_provider<'a, 'b: 'a>(
&'b self,
san_info: &'a BTreeSet<InternedString>,
) -> impl Future<Output = Option<impl AsRef<AcmeProvider> + Send + 'b>> + Send + 'a;
}
impl<'b, A, M, S> TlsHandler<A> for &'b AcmeTlsHandler<'b, M, S>
impl<'a, A, M, S> TlsHandler<'a, A> for Arc<AcmeTlsHandler<M, S>>
where
A: Accept,
A: Accept + 'a,
<A as Accept>::Metadata: Send + Sync,
for<'a> M: DbAccess<AcmeCertStore, Key<'a> = ()>
+ DbAccess<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ DbAccessMut<AcmeCertStore, Key<'a> = ()>
for<'m> M: DbAccessByKey<AcmeSettings, Key<'m> = &'m AcmeProvider>
+ DbAccessMut<AcmeCertStore>
+ HasModel<Model = Model<M>>
+ Send
+ Sync,
S: GetAcmeProvider<'b> + Clone + Send + Sync + 'b,
S: GetAcmeProvider + Clone + Send + Sync,
{
async fn get_config<'a>(
self,
hello: &'a tokio_rustls::rustls::server::ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig>
where
Self: 'a,
A: 'a,
<A as Accept>::Metadata: 'a,
{
async fn get_config(
&'a mut self,
hello: &'a ClientHello<'a>,
_: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig> {
let domain = hello.server_name()?;
if hello
.alpn()
@@ -238,14 +231,12 @@ impl AcmeCertStore {
}
impl DbAccess<AcmeCertStore> for Database {
type Key<'a> = ();
fn access<'a>(db: &'a Model<Self>, _: Self::Key<'_>) -> &'a Model<AcmeCertStore> {
fn access<'a>(db: &'a Model<Self>) -> &'a Model<AcmeCertStore> {
db.as_private().as_key_store().as_acme()
}
}
impl DbAccessMut<AcmeCertStore> for Database {
type Key<'a> = ();
fn access_mut<'a>(db: &'a mut Model<Self>, _: Self::Key<'_>) -> &'a mut Model<AcmeCertStore> {
fn access_mut<'a>(db: &'a mut Model<Self>) -> &'a mut Model<AcmeCertStore> {
db.as_private_mut().as_key_store_mut().as_acme_mut()
}
}
@@ -260,18 +251,14 @@ pub struct AcmeCertCache<'a, M: HasModel>(pub &'a TypedPatchDb<M>);
#[async_trait::async_trait]
impl<'a, M> async_acme::cache::AcmeCache for AcmeCertCache<'a, M>
where
for<'b> M: HasModel<Model = Model<M>>
+ DbAccess<AcmeCertStore, Key<'b> = ()>
+ DbAccessMut<AcmeCertStore, Key<'b> = ()>
+ Send
+ Sync,
M: HasModel<Model = Model<M>> + DbAccessMut<AcmeCertStore> + Send + Sync,
{
type Error = ErrorData;
async fn read_account(&self, contacts: &[&str]) -> Result<Option<Vec<u8>>, Self::Error> {
let contacts = JsonKey::new(contacts.into_iter().map(|s| (*s).to_owned()).collect_vec());
let peek = self.0.peek().await;
let Some(account) = M::access(&peek, ()).as_accounts().as_idx(&contacts) else {
let Some(account) = M::access(&peek).as_accounts().as_idx(&contacts) else {
return Ok(None);
};
Ok(Some(account.de()?.0.document.into_vec()))
@@ -285,7 +272,7 @@ where
};
self.0
.mutate(|db| {
M::access_mut(db, ())
M::access_mut(db)
.as_accounts_mut()
.insert(&contacts, &Pem::new(key))
})
@@ -312,7 +299,7 @@ where
.parse::<Url>()
.with_kind(ErrorKind::ParseUrl)?;
let peek = self.0.peek().await;
let Some(cert) = M::access(&peek, ())
let Some(cert) = M::access(&peek)
.as_certs()
.as_idx(&directory_url)
.and_then(|a| a.as_idx(&identifiers))
@@ -370,7 +357,7 @@ where
};
self.0
.mutate(|db| {
M::access_mut(db, ())
M::access_mut(db)
.as_certs_mut()
.upsert(&directory_url, || Ok(BTreeMap::new()))?
.insert(&identifiers, &cert)
@@ -443,6 +430,11 @@ impl AsRef<str> for AcmeProvider {
self.0.as_str()
}
}
impl AsRef<AcmeProvider> for AcmeProvider {
fn as_ref(&self) -> &AcmeProvider {
self
}
}
impl ValueParserFactory for AcmeProvider {
type Parser = FromStrParser<Self>;
fn value_parser() -> Self::Parser {

View File

@@ -377,6 +377,9 @@ impl PortForwardController {
}
async fn forward(interface: &str, source: SocketAddr, target: SocketAddr) -> Result<(), Error> {
if interface == START9_BRIDGE_IFACE {
return Ok(());
}
if source.is_ipv6() {
return Ok(()); // TODO: socat? ip6tables?
}
@@ -393,6 +396,9 @@ async fn forward(interface: &str, source: SocketAddr, target: SocketAddr) -> Res
}
async fn unforward(interface: &str, source: SocketAddr, target: SocketAddr) -> Result<(), Error> {
if interface == START9_BRIDGE_IFACE {
return Ok(());
}
if source.is_ipv6() {
return Ok(()); // TODO: socat? ip6tables?
}

View File

@@ -552,6 +552,7 @@ async fn watch_ip(
let managed = device_proxy.managed().await?;
if !managed {
dbg!("unmanaged", &iface);
return Ok(());
}
let dac = device_proxy.active_connection().await?;
@@ -596,10 +597,6 @@ async fn watch_ip(
_ => None,
};
if device_type == Some(NetworkInterfaceType::Loopback) {
return Ok(());
}
let name = InternedString::from(active_connection_proxy.id().await?);
let dhcp4_config = active_connection_proxy.dhcp4_config().await?;
@@ -676,7 +673,14 @@ async fn watch_ip(
.into_iter()
.map(IpNet::try_from)
.try_collect()?;
let wan_ip = if !subnets.is_empty() {
let wan_ip = if !subnets.is_empty()
&& !matches!(
device_type,
Some(
NetworkInterfaceType::Bridge
| NetworkInterfaceType::Loopback
)
) {
match get_wan_ipv4(iface.as_str()).await {
Ok(a) => a,
Err(e) => {
@@ -1535,6 +1539,17 @@ pub struct NetworkInterfaceListenerAcceptMetadata<B: Bind> {
pub inner: <B::Accept as Accept>::Metadata,
pub info: GatewayInfo,
}
impl<B: Bind> Clone for NetworkInterfaceListenerAcceptMetadata<B>
where
<B::Accept as Accept>::Metadata: Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
info: self.info.clone(),
}
}
}
impl<B, V> Visit<V> for NetworkInterfaceListenerAcceptMetadata<B>
where
B: Bind,

View File

@@ -8,6 +8,7 @@ use crate::prelude::*;
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[model = "Model<Self>"]
#[serde(rename_all = "camelCase")]
pub struct KeyStore {
pub onion: OnionStore,
pub local_certs: CertStore,

View File

@@ -9,9 +9,10 @@ use ipnet::IpNet;
use models::{GatewayId, HostId, OptionExt, PackageId};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio_rustls::rustls::ClientConfig as TlsClientConfig;
use tracing::instrument;
use crate::db::model::public::{NetworkInterfaceInfo, NetworkInterfaceType};
use crate::db::model::public::NetworkInterfaceType;
use crate::db::model::Database;
use crate::error::ErrorCollection;
use crate::hostname::Hostname;
@@ -28,7 +29,7 @@ use crate::net::service_interface::{GatewayInfo, HostnameInfo, IpHostname, Onion
use crate::net::socks::SocksController;
use crate::net::tor::{OnionAddress, TorController, TorSecretKey};
use crate::net::utils::ipv6_is_local;
use crate::net::vhost::{AlpnInfo, ProxyTarget, VHostController};
use crate::net::vhost::{AlpnInfo, DynVHostTarget, ProxyTarget, VHostController};
use crate::prelude::*;
use crate::service::effects::callbacks::ServiceCallbacks;
use crate::util::serde::MaybeUtf8String;
@@ -38,6 +39,7 @@ pub struct NetController {
pub(crate) db: TypedPatchDb<Database>,
pub(super) tor: TorController,
pub(super) vhost: VHostController,
pub(super) tls_client_config: Arc<TlsClientConfig>,
pub(crate) net_iface: Arc<NetworkInterfaceController>,
pub(super) dns: DnsController,
pub(super) forward: PortForwardController,
@@ -55,10 +57,24 @@ impl NetController {
let net_iface = Arc::new(NetworkInterfaceController::new(db.clone()));
let tor = TorController::new()?;
let socks = SocksController::new(socks_listen, tor.clone())?;
let crypto_provider = Arc::new(tokio_rustls::rustls::crypto::ring::default_provider());
let tls_client_config = Arc::new(crate::net::tls::client_config(
crypto_provider.clone(),
[&*db
.peek()
.await
.as_private()
.as_key_store()
.as_local_certs()
.as_root_cert()
.de()?
.0],
)?);
Ok(Self {
db: db.clone(),
tor,
vhost: VHostController::new(db.clone(), net_iface.clone()),
vhost: VHostController::new(db.clone(), net_iface.clone(), crypto_provider),
tls_client_config,
dns: DnsController::init(db, &net_iface.watcher).await?,
forward: PortForwardController::new(net_iface.watcher.subscribe()),
net_iface,
@@ -267,7 +283,9 @@ impl NetServiceData {
filter: bind.net.clone().into_dyn(),
acme: None,
addr,
connect_ssl: connect_ssl.clone(),
connect_ssl: connect_ssl
.clone()
.map(|_| ctrl.tls_client_config.clone()),
},
);
}
@@ -288,7 +306,9 @@ impl NetServiceData {
.into_dyn(),
acme: None,
addr,
connect_ssl: connect_ssl.clone(),
connect_ssl: connect_ssl
.clone()
.map(|_| ctrl.tls_client_config.clone()),
},
); // TODO: wrap onion ssl stream directly in tor ctrl
}
@@ -315,7 +335,9 @@ impl NetServiceData {
.into_dyn(),
acme: public.acme.clone(),
addr,
connect_ssl: connect_ssl.clone(),
connect_ssl: connect_ssl
.clone()
.map(|_| ctrl.tls_client_config.clone()),
},
);
vhosts.insert(
@@ -340,7 +362,9 @@ impl NetServiceData {
.into_dyn(),
acme: public.acme.clone(),
addr,
connect_ssl: connect_ssl.clone(),
connect_ssl: connect_ssl
.clone()
.map(|_| ctrl.tls_client_config.clone()),
},
);
} else {
@@ -354,7 +378,9 @@ impl NetServiceData {
.into_dyn(),
acme: None,
addr,
connect_ssl: connect_ssl.clone(),
connect_ssl: connect_ssl
.clone()
.map(|_| ctrl.tls_client_config.clone()),
},
);
}
@@ -379,7 +405,9 @@ impl NetServiceData {
.into_dyn(),
acme: public.acme.clone(),
addr,
connect_ssl: connect_ssl.clone(),
connect_ssl: connect_ssl
.clone()
.map(|_| ctrl.tls_client_config.clone()),
},
);
} else {
@@ -393,7 +421,9 @@ impl NetServiceData {
.into_dyn(),
acme: None,
addr,
connect_ssl: connect_ssl.clone(),
connect_ssl: connect_ssl
.clone()
.map(|_| ctrl.tls_client_config.clone()),
},
);
}
@@ -427,6 +457,11 @@ impl NetServiceData {
hostname_info.remove(port).unwrap_or_default();
for (gateway_id, info) in net_ifaces
.iter()
.filter(|(_, info)| {
info.ip_info.as_ref().map_or(false, |i| {
!matches!(i.device_type, Some(NetworkInterfaceType::Bridge))
})
})
.filter(|(id, info)| bind.net.filter(id, info))
{
let gateway = GatewayInfo {
@@ -651,7 +686,10 @@ impl NetServiceData {
if let Some(prev) = prev {
prev
} else {
(target.clone(), ctrl.vhost.add(key.0, key.1, target)?)
(
target.clone(),
ctrl.vhost.add(key.0, key.1, DynVHostTarget::new(target))?,
)
},
);
} else {

View File

@@ -1,7 +1,8 @@
use std::cmp::{Ordering, min};
use std::cmp::{min, Ordering};
use std::collections::{BTreeMap, BTreeSet};
use std::net::IpAddr;
use std::path::Path;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use futures::FutureExt;
@@ -14,18 +15,32 @@ use openssl::error::ErrorStack;
use openssl::hash::MessageDigest;
use openssl::nid::Nid;
use openssl::pkey::{PKey, Private};
use openssl::x509::{X509, X509Builder, X509Extension, X509NameBuilder};
use openssl::x509::extension::{
AuthorityKeyIdentifier, BasicConstraints, KeyUsage, SubjectAlternativeName,
SubjectKeyIdentifier,
};
use openssl::x509::{X509Builder, X509NameBuilder, X509};
use openssl::*;
use patch_db::HasModel;
use serde::{Deserialize, Serialize};
use tokio_rustls::rustls::crypto::CryptoProvider;
use tokio_rustls::rustls::pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio_rustls::rustls::server::ClientHello;
use tokio_rustls::rustls::ServerConfig;
use tracing::instrument;
use visit_rs::Visit;
use crate::SOURCE_DATE;
use crate::account::AccountInfo;
use crate::db::model::Database;
use crate::db::{DbAccess, DbAccessMut};
use crate::hostname::Hostname;
use crate::init::check_time_is_synchronized;
use crate::net::gateway::GatewayInfo;
use crate::net::tls::TlsHandler;
use crate::net::web_server::{extract, Accept, ExtractVisitor, TcpMetadata};
use crate::prelude::*;
use crate::util::serde::Pem;
use crate::SOURCE_DATE;
pub fn gen_nistp256() -> Result<PKey<Private>, ErrorStack> {
PKey::from_ec_key(EcKey::generate(&*EcGroup::from_curve_name(
@@ -130,6 +145,16 @@ impl Model<CertStore> {
})
}
}
impl DbAccess<CertStore> for Database {
fn access<'a>(db: &'a Model<Self>) -> &'a Model<CertStore> {
db.as_private().as_key_store().as_local_certs()
}
}
impl DbAccessMut<CertStore> for Database {
fn access_mut<'a>(db: &'a mut Model<Self>) -> &'a mut Model<CertStore> {
db.as_private_mut().as_key_store_mut().as_local_certs_mut()
}
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct CertData {
@@ -305,22 +330,16 @@ pub fn make_root_cert(
let cfg = conf::Conf::new(conf::ConfMethod::default())?;
let ctx = builder.x509v3_context(None, Some(&cfg));
// subjectKeyIdentifier = hash
let subject_key_identifier =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_KEY_IDENTIFIER, "hash")?;
let subject_key_identifier = SubjectKeyIdentifier::new().build(&ctx)?;
// basicConstraints = critical, CA:true, pathlen:0
let basic_constraints = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::BASIC_CONSTRAINTS,
"critical,CA:true",
)?;
let basic_constraints = BasicConstraints::new().critical().ca().build()?;
// keyUsage = critical, digitalSignature, cRLSign, keyCertSign
let key_usage = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::KEY_USAGE,
"critical,digitalSignature,cRLSign,keyCertSign",
)?;
let key_usage = KeyUsage::new()
.critical()
.digital_signature()
.crl_sign()
.key_cert_sign()
.build()?;
builder.append_extension(subject_key_identifier)?;
builder.append_extension(basic_constraints)?;
builder.append_extension(key_usage)?;
@@ -355,30 +374,23 @@ pub fn make_int_cert(
let cfg = conf::Conf::new(conf::ConfMethod::default())?;
let ctx = builder.x509v3_context(Some(&signer.1), Some(&cfg));
// subjectKeyIdentifier = hash
let subject_key_identifier =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_KEY_IDENTIFIER, "hash")?;
let subject_key_identifier = SubjectKeyIdentifier::new().build(&ctx)?;
// authorityKeyIdentifier = keyid:always,issuer
let authority_key_identifier = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::AUTHORITY_KEY_IDENTIFIER,
"keyid:always,issuer",
)?;
let authority_key_identifier = AuthorityKeyIdentifier::new()
.keyid(false)
.issuer(true)
.build(&ctx)?;
// basicConstraints = critical, CA:true, pathlen:0
let basic_constraints = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::BASIC_CONSTRAINTS,
"critical,CA:true,pathlen:0",
)?;
let basic_constraints = BasicConstraints::new().critical().ca().pathlen(0).build()?;
// keyUsage = critical, digitalSignature, cRLSign, keyCertSign
let key_usage = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::KEY_USAGE,
"critical,digitalSignature,cRLSign,keyCertSign",
)?;
let key_usage = KeyUsage::new()
.critical()
.digital_signature()
.crl_sign()
.key_cert_sign()
.build()?;
builder.append_extension(subject_key_identifier)?;
builder.append_extension(authority_key_identifier)?;
builder.append_extension(basic_constraints)?;
@@ -428,6 +440,16 @@ impl SANInfo {
}
Self { dns, ips }
}
pub fn x509_extension(&self) -> SubjectAlternativeName {
let mut san = SubjectAlternativeName::new();
for h in &self.dns {
san.dns(&h.as_str());
}
for ip in &self.ips {
san.ip(&ip.to_string());
}
san
}
}
impl std::fmt::Display for SANInfo {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -490,28 +512,20 @@ pub fn make_leaf_cert(
// Extensions
let cfg = conf::Conf::new(conf::ConfMethod::default())?;
let ctx = builder.x509v3_context(Some(&signer.1), Some(&cfg));
// subjectKeyIdentifier = hash
let subject_key_identifier =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_KEY_IDENTIFIER, "hash")?;
// authorityKeyIdentifier = keyid:always,issuer
let authority_key_identifier = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::AUTHORITY_KEY_IDENTIFIER,
"keyid,issuer:always",
)?;
let basic_constraints =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::BASIC_CONSTRAINTS, "CA:FALSE")?;
let key_usage = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::KEY_USAGE,
"critical,digitalSignature,keyEncipherment",
)?;
let san_string = applicant.1.to_string();
let subject_alt_name =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_ALT_NAME, &san_string)?;
let subject_key_identifier = SubjectKeyIdentifier::new().build(&ctx)?;
let authority_key_identifier = AuthorityKeyIdentifier::new()
.keyid(false)
.issuer(true)
.build(&ctx)?;
let subject_alt_name = applicant.1.x509_extension().build(&ctx)?;
let basic_constraints = BasicConstraints::new().build()?;
let key_usage = KeyUsage::new()
.critical()
.digital_signature()
.key_encipherment()
.build()?;
builder.append_extension(subject_key_identifier)?;
builder.append_extension(authority_key_identifier)?;
builder.append_extension(subject_alt_name)?;
@@ -561,28 +575,20 @@ pub fn make_self_signed(applicant: (&PKey<Private>, &SANInfo)) -> Result<X509, E
// Extensions
let cfg = conf::Conf::new(conf::ConfMethod::default())?;
let ctx = builder.x509v3_context(None, Some(&cfg));
// subjectKeyIdentifier = hash
let subject_key_identifier =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_KEY_IDENTIFIER, "hash")?;
// authorityKeyIdentifier = keyid:always,issuer
let authority_key_identifier = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::AUTHORITY_KEY_IDENTIFIER,
"keyid,issuer:always",
)?;
let basic_constraints =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::BASIC_CONSTRAINTS, "CA:FALSE")?;
let key_usage = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::KEY_USAGE,
"critical,digitalSignature,keyEncipherment",
)?;
let san_string = applicant.1.to_string();
let subject_alt_name =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_ALT_NAME, &san_string)?;
let subject_key_identifier = SubjectKeyIdentifier::new().build(&ctx)?;
let authority_key_identifier = AuthorityKeyIdentifier::new()
.keyid(false)
.issuer(true)
.build(&ctx)?;
let subject_alt_name = applicant.1.x509_extension().build(&ctx)?;
let basic_constraints = BasicConstraints::new().build()?;
let key_usage = KeyUsage::new()
.critical()
.digital_signature()
.key_encipherment()
.build()?;
builder.append_extension(subject_key_identifier)?;
builder.append_extension(authority_key_identifier)?;
builder.append_extension(subject_alt_name)?;
@@ -594,3 +600,98 @@ pub fn make_self_signed(applicant: (&PKey<Private>, &SANInfo)) -> Result<X509, E
let cert = builder.build();
Ok(cert)
}
pub struct RootCaTlsHandler<M: HasModel> {
pub db: TypedPatchDb<M>,
pub crypto_provider: Arc<CryptoProvider>,
}
impl<M: HasModel> Clone for RootCaTlsHandler<M> {
fn clone(&self) -> Self {
Self {
db: self.db.clone(),
crypto_provider: self.crypto_provider.clone(),
}
}
}
impl<'a, A, M> TlsHandler<'a, A> for RootCaTlsHandler<M>
where
A: Accept + 'a,
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
+ Visit<ExtractVisitor<GatewayInfo>>
+ Clone
+ Send
+ Sync
+ 'static,
M: HasModel<Model = Model<M>> + DbAccessMut<CertStore> + Send + Sync,
{
async fn get_config(
&mut self,
hello: &ClientHello<'_>,
metadata: &<A as Accept>::Metadata,
) -> Option<ServerConfig> {
let hostnames: BTreeSet<InternedString> = hello
.server_name()
.map(InternedString::from)
.into_iter()
.chain(
extract::<TcpMetadata, _>(metadata)
.map(|m| m.local_addr.ip())
.as_ref()
.map(InternedString::from_display),
)
.chain(
extract::<GatewayInfo, _>(metadata)
.and_then(|i| i.info.ip_info)
.and_then(|i| i.wan_ip)
.as_ref()
.map(InternedString::from_display),
)
.collect();
let cert = self
.db
.mutate(|db| M::access_mut(db).cert_for(&hostnames))
.await
.result
.log_err()?;
let cfg = ServerConfig::builder_with_provider(self.crypto_provider.clone())
.with_safe_default_protocol_versions()
.log_err()?
.with_no_client_auth();
if hello
.signature_schemes()
.contains(&tokio_rustls::rustls::SignatureScheme::ED25519)
{
cfg.with_single_cert(
cert.fullchain_ed25519()
.into_iter()
.map(|c| {
Ok(tokio_rustls::rustls::pki_types::CertificateDer::from(
c.to_der()?,
))
})
.collect::<Result<_, Error>>()
.log_err()?,
PrivateKeyDer::from(PrivatePkcs8KeyDer::from(
cert.leaf.keys.ed25519.private_key_to_pkcs8().log_err()?,
)),
)
} else {
cfg.with_single_cert(
cert.fullchain_nistp256()
.into_iter()
.map(|c| {
Ok(tokio_rustls::rustls::pki_types::CertificateDer::from(
c.to_der()?,
))
})
.collect::<Result<_, Error>>()
.log_err()?,
PrivateKeyDer::from(PrivatePkcs8KeyDer::from(
cert.leaf.keys.nistp256.private_key_to_pkcs8().log_err()?,
)),
)
}
.log_err()
}
}

View File

@@ -4,10 +4,13 @@ use std::task::Poll;
use futures::future::BoxFuture;
use futures::FutureExt;
use imbl_value::InternedString;
use openssl::x509::X509Ref;
use tokio::io::AsyncWriteExt;
use tokio_rustls::rustls::crypto::CryptoProvider;
use tokio_rustls::rustls::pki_types::CertificateDer;
use tokio_rustls::rustls::server::{Acceptor, ClientHello, ResolvesServerCert};
use tokio_rustls::rustls::sign::CertifiedKey;
use tokio_rustls::rustls::ServerConfig;
use tokio_rustls::rustls::{ClientConfig, RootCertStore, ServerConfig};
use tokio_rustls::LazyConfigAcceptor;
use visit_rs::{Visit, VisitFields};
@@ -38,35 +41,28 @@ impl<V: MetadataVisitor> Visit<V> for TlsHandshakeInfo {
}
}
pub trait TlsHandler<A: Accept> {
fn get_config<'a>(
self,
pub trait TlsHandler<'a, A: Accept> {
fn get_config(
&'a mut self,
hello: &'a ClientHello<'a>,
metadata: &'a A::Metadata,
) -> impl Future<Output = Option<ServerConfig>> + Send + 'a
where
Self: 'a,
A: 'a,
A::Metadata: 'a;
) -> impl Future<Output = Option<ServerConfig>> + Send + 'a;
}
#[derive(Clone)]
pub struct ChainedHandler<H0, H1>(pub H0, pub H1);
impl<A, H0, H1> TlsHandler<A> for ChainedHandler<H0, H1>
impl<'a, A, H0, H1> TlsHandler<'a, A> for ChainedHandler<H0, H1>
where
A: Accept,
A: Accept + 'a,
<A as Accept>::Metadata: Send + Sync,
H0: TlsHandler<A> + Send,
H1: TlsHandler<A> + Send,
H0: TlsHandler<'a, A> + Send,
H1: TlsHandler<'a, A> + Send,
{
async fn get_config<'a>(
self,
async fn get_config(
&'a mut self,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig>
where
Self: 'a,
{
) -> Option<ServerConfig> {
if let Some(config) = self.0.get_config(hello, metadata).await {
return Some(config);
}
@@ -74,6 +70,7 @@ where
}
}
#[derive(Clone)]
pub struct TlsHandlerWrapper<I, W> {
pub inner: I,
pub wrapper: W,
@@ -81,7 +78,7 @@ pub struct TlsHandlerWrapper<I, W> {
pub trait WrapTlsHandler<A: Accept> {
fn wrap<'a>(
self,
&'a mut self,
prev: ServerConfig,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
@@ -90,23 +87,18 @@ pub trait WrapTlsHandler<A: Accept> {
Self: 'a;
}
impl<A, I, W> TlsHandler<A> for TlsHandlerWrapper<I, W>
impl<'a, A, I, W> TlsHandler<'a, A> for TlsHandlerWrapper<I, W>
where
A: Accept,
A: Accept + 'a,
<A as Accept>::Metadata: Send + Sync,
I: TlsHandler<A> + Send,
I: TlsHandler<'a, A> + Send,
W: WrapTlsHandler<A> + Send,
{
async fn get_config<'a>(
self,
async fn get_config(
&'a mut self,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig>
where
Self: 'a,
A: 'a,
<A as Accept>::Metadata: 'a,
{
) -> Option<ServerConfig> {
let prev = self.inner.get_config(hello, metadata).await?;
self.wrapper.wrap(prev, hello, metadata).await
}
@@ -120,13 +112,20 @@ impl ResolvesServerCert for SingleCertResolver {
}
}
pub struct TlsListener<A: Accept, H: TlsHandler<A>> {
pub struct TlsListener<A: Accept, H: for<'a> TlsHandler<'a, A>> {
pub accept: A,
pub tls_handler: H,
in_progress:
Vec<BoxFuture<'static, Result<Option<(TlsMetadata<A::Metadata>, AcceptStream)>, Error>>>,
in_progress: Vec<
BoxFuture<
'static,
(
H,
Result<Option<(TlsMetadata<A::Metadata>, AcceptStream)>, Error>,
),
>,
>,
}
impl<A: Accept, H: TlsHandler<A>> TlsListener<A, H> {
impl<A: Accept, H: for<'a> TlsHandler<'a, A>> TlsListener<A, H> {
pub fn new(accept: A, cert_handler: H) -> Self {
Self {
accept,
@@ -137,93 +136,111 @@ impl<A: Accept, H: TlsHandler<A>> TlsListener<A, H> {
}
impl<A, H> Accept for TlsListener<A, H>
where
A: Accept,
A: Accept + 'static,
A::Metadata: Send + 'static,
H: TlsHandler<A> + Clone + Send + 'static,
for<'a> H: TlsHandler<'a, A> + Clone + Send + 'static,
{
type Metadata = TlsMetadata<A::Metadata>;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
if let Some((idx, res)) = self
.in_progress
.iter_mut()
.enumerate()
.find_map(|(idx, fut)| match fut.poll_unpin(cx) {
Poll::Ready(a) => Some((idx, a)),
Poll::Pending => None,
})
{
self.in_progress.swap_remove(idx);
if let Some(res) = res.transpose() {
return Poll::Ready(res);
}
}
if let Poll::Ready((metadata, stream)) = self.accept.poll_accept(cx)? {
let tls_handler = self.tls_handler.clone();
self.in_progress.push(
async move {
let mut acceptor =
LazyConfigAcceptor::new(Acceptor::default(), BackTrackingIO::new(stream));
let mut mid: tokio_rustls::StartHandshake<BackTrackingIO<AcceptStream>> =
match (&mut acceptor).await {
Ok(a) => a,
Err(e) => {
let mut stream = acceptor.take_io().or_not_found("acceptor io")?;
let (_, buf) = stream.rewind();
if std::str::from_utf8(buf)
.ok()
.and_then(|buf| {
buf.lines()
.map(|l| l.trim())
.filter(|l| !l.is_empty())
.next()
})
.map_or(false, |buf| {
regex::Regex::new("[A-Z]+ (.+) HTTP/1")
.unwrap()
.is_match(buf)
})
{
handle_http_on_https(stream).await.log_err();
return Ok(None);
} else {
return Err(e).with_kind(ErrorKind::Network);
}
}
};
let hello = mid.client_hello();
if let Some(cfg) = tls_handler.get_config(&hello, &metadata).await {
let metadata = TlsMetadata {
inner: metadata,
tls_info: TlsHandshakeInfo {
sni: hello.server_name().map(InternedString::intern),
alpn: hello
.alpn()
.into_iter()
.flatten()
.map(|a| MaybeUtf8String(a.to_vec()))
.collect(),
},
};
let buffered = mid.io.stop_buffering();
mid.io
.write_all(&buffered)
.await
.with_kind(ErrorKind::Network)?;
return Ok(Some((
metadata,
Box::pin(mid.into_stream(Arc::new(cfg)).await?) as AcceptStream,
)));
}
Ok(None)
loop {
if let Some((idx, (handler, res))) =
self.in_progress
.iter_mut()
.enumerate()
.find_map(|(idx, fut)| match fut.poll_unpin(cx) {
Poll::Ready(a) => Some((idx, a)),
Poll::Pending => None,
})
{
drop(self.in_progress.swap_remove(idx));
if let Some(res) = res.transpose() {
self.tls_handler = handler;
return Poll::Ready(res);
}
.boxed(),
);
continue;
}
if let Poll::Ready((metadata, stream)) = self.accept.poll_accept(cx)? {
crate::dbg!("ACCEPTED");
let mut tls_handler = self.tls_handler.clone();
self.in_progress.push(
async move {
let res = async {
let mut acceptor = LazyConfigAcceptor::new(
Acceptor::default(),
BackTrackingIO::new(stream),
);
let mut mid: tokio_rustls::StartHandshake<
BackTrackingIO<AcceptStream>,
> = match (&mut acceptor).await {
Ok(a) => a,
Err(e) => {
let mut stream =
acceptor.take_io().or_not_found("acceptor io")?;
let (_, buf) = stream.rewind();
if std::str::from_utf8(buf)
.ok()
.and_then(|buf| {
buf.lines()
.map(|l| l.trim())
.filter(|l| !l.is_empty())
.next()
})
.map_or(false, |buf| {
regex::Regex::new("[A-Z]+ (.+) HTTP/1")
.unwrap()
.is_match(buf)
})
{
handle_http_on_https(stream).await.log_err();
return Ok(None);
} else {
return Err(e).with_kind(ErrorKind::Network);
}
}
};
let hello = mid.client_hello();
crate::dbg!("getting config");
if let Some(cfg) = tls_handler.get_config(&hello, &metadata).await {
crate::dbg!("config gotten");
let metadata = TlsMetadata {
inner: metadata,
tls_info: TlsHandshakeInfo {
sni: hello.server_name().map(InternedString::intern),
alpn: hello
.alpn()
.into_iter()
.flatten()
.map(|a| MaybeUtf8String(a.to_vec()))
.collect(),
},
};
let buffered = mid.io.stop_buffering();
mid.io
.write_all(&buffered)
.await
.with_kind(ErrorKind::Network)?;
return Ok(Some((
metadata,
Box::pin(mid.into_stream(Arc::new(cfg)).await?) as AcceptStream,
)));
}
crate::dbg!("no config");
Ok(None)
}
.await;
(tls_handler, res)
}
.boxed(),
);
continue;
}
break;
}
Poll::Pending
@@ -280,3 +297,20 @@ async fn handle_http_on_https(stream: impl ReadWriter + Unpin + 'static) -> Resu
.await
.map_err(|e| Error::new(color_eyre::eyre::Report::msg(e), ErrorKind::Network))
}
pub fn client_config<'a, I: IntoIterator<Item = &'a X509Ref>>(
crypto_provider: Arc<CryptoProvider>,
root_certs: I,
) -> Result<ClientConfig, Error> {
let mut certs = RootCertStore::empty();
for cert in root_certs {
certs
.add(CertificateDer::from_slice(&cert.to_der()?))
.with_kind(ErrorKind::OpenSsl)?;
}
Ok(ClientConfig::builder_with_provider(crypto_provider.clone())
.with_safe_default_protocol_versions()
.with_kind(ErrorKind::OpenSsl)?
.with_root_certificates(certs)
.with_no_client_auth())
}

View File

@@ -125,7 +125,7 @@ pub async fn remove_tunnel(
return Ok(());
};
if existing.as_device_type().de()? != Some(NetworkInterfaceType::Wireguard) {
if existing.as_deref().as_device_type().de()? != Some(NetworkInterfaceType::Wireguard) {
return Err(Error::new(
eyre!("network interface {id} is not a proxy"),
ErrorKind::InvalidRequest,

View File

@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV6};
use std::path::Path;
@@ -5,7 +6,6 @@ use async_stream::try_stream;
use color_eyre::eyre::eyre;
use futures::stream::BoxStream;
use futures::{StreamExt, TryStreamExt};
use imbl::OrdMap;
use imbl_value::InternedString;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use models::GatewayId;
@@ -13,13 +13,11 @@ use nix::net::if_::if_nametoindex;
use tokio::net::{TcpListener, TcpStream};
use tokio::process::Command;
use crate::db::model::public::{NetworkInterfaceInfo, NetworkInterfaceType};
use crate::db::model::public::{IpInfo, NetworkInterfaceType};
use crate::prelude::*;
use crate::util::collections::OrdMapIterMut;
use crate::util::Invoke;
pub async fn load_network_interface_info() -> Result<OrdMap<GatewayId, NetworkInterfaceInfo>, Error>
{
pub async fn load_ip_info() -> Result<BTreeMap<GatewayId, IpInfo>, Error> {
let output = String::from_utf8(
Command::new("ip")
.arg("-o")
@@ -36,14 +34,14 @@ pub async fn load_network_interface_info() -> Result<OrdMap<GatewayId, NetworkIn
)
};
let mut res = OrdMap::<GatewayId, NetworkInterfaceInfo>::new();
let mut res = BTreeMap::<GatewayId, IpInfo>::new();
for line in output.lines() {
let split = line.split_ascii_whitespace().collect::<Vec<_>>();
let iface = GatewayId::from(InternedString::from(*split.get(1).ok_or_else(&err_fn)?));
let subnet: IpNet = split.get(3).ok_or_else(&err_fn)?.parse()?;
let info = res.entry(iface).or_default();
let ip_info = info.ip_info.get_or_insert_default();
let ip_info = res.entry(iface.clone()).or_default();
ip_info.name = iface.into();
ip_info.scope_id = split
.get(0)
.ok_or_else(&err_fn)?
@@ -53,8 +51,7 @@ pub async fn load_network_interface_info() -> Result<OrdMap<GatewayId, NetworkIn
ip_info.subnets.insert(subnet);
}
for (id, info) in OrdMapIterMut::from(&mut res) {
let ip_info = info.ip_info.get_or_insert_default();
for (id, ip_info) in res.iter_mut() {
ip_info.device_type = probe_iface_type(id.as_str()).await;
}

View File

@@ -2,50 +2,45 @@ use std::any::Any;
use std::collections::{BTreeMap, BTreeSet};
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, Weak};
use std::task::{ready, Poll};
use async_acme::acme::{Identifier, ACME_TLS_ALPN_NAME};
use axum::body::Body;
use axum::extract::Request;
use axum::response::Response;
use async_acme::acme::ACME_TLS_ALPN_NAME;
use color_eyre::eyre::eyre;
use futures::future::BoxFuture;
use futures::FutureExt;
use helpers::NonDetachingJoinHandle;
use http::{Extensions, Uri};
use imbl::OrdMap;
use imbl_value::{InOMap, InternedString};
use itertools::Itertools;
use models::{GatewayId, ResultExt};
use models::ResultExt;
use rpc_toolkit::{from_fn, Context, HandlerArgs, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
use tokio::sync::watch;
use tokio_rustls::rustls::crypto::CryptoProvider;
use tokio_rustls::rustls::pki_types::{
CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName,
};
use tokio_rustls::rustls::server::{Acceptor, ClientHello};
use tokio_rustls::rustls::{ClientConfig, RootCertStore, ServerConfig};
use tokio_rustls::{LazyConfigAcceptor, TlsConnector};
use tokio_stream::wrappers::WatchStream;
use tokio_stream::StreamExt;
use tokio_rustls::rustls::pki_types::ServerName;
use tokio_rustls::rustls::server::ClientHello;
use tokio_rustls::rustls::{ClientConfig, ServerConfig};
use tokio_rustls::TlsConnector;
use tracing::instrument;
use ts_rs::TS;
use visit_rs::Visit;
use crate::context::{CliContext, RpcContext};
use crate::db::model::public::NetworkInterfaceInfo;
use crate::db::model::public::AcmeSettings;
use crate::db::model::Database;
use crate::net::acme::{AcmeCertCache, AcmeTlsAlpnCache, AcmeTlsHandler};
use crate::net::gateway::{
AnyFilter, BindTcp, DynInterfaceFilter, GatewayInfo, InterfaceFilter, NetworkInterfaceController, NetworkInterfaceListener
use crate::db::{DbAccessByKey, DbAccessMut};
use crate::net::acme::{
AcmeCertStore, AcmeProvider, AcmeTlsAlpnCache, AcmeTlsHandler, GetAcmeProvider,
};
use crate::net::static_server::server_error;
use crate::net::tls::{ChainedHandler, TlsHandler, TlsListener, WrapTlsHandler};
use crate::net::web_server::{Accept, AcceptStream, extract};
use crate::net::gateway::{
AnyFilter, BindTcp, DynInterfaceFilter, GatewayInfo, InterfaceFilter,
NetworkInterfaceController, NetworkInterfaceListener,
};
use crate::net::ssl::{CertStore, RootCaTlsHandler};
use crate::net::tls::{
ChainedHandler, TlsHandlerWrapper, TlsListener, TlsMetadata, WrapTlsHandler,
};
use crate::net::web_server::{extract, Accept, AcceptStream, ExtractVisitor, TcpMetadata};
use crate::prelude::*;
use crate::util::collections::EqSet;
use crate::util::io::BackTrackingIO;
use crate::util::serde::{display_serializable, HandlerExtSerde, MaybeUtf8String};
use crate::util::sync::{SyncMutex, Watch};
@@ -63,7 +58,7 @@ pub fn vhost_api<C: Context>() -> ParentHandler<C> {
}
let mut table = Table::new();
table.add_row(row![bc => "FROM", "TO", "GATEWAYS", "CONNECT SSL", "ACTIVE"]);
table.add_row(row![bc => "FROM", "TO", "ACTIVE"]);
for (external, targets) in res {
for (host, targets) in targets {
@@ -74,9 +69,7 @@ pub fn vhost_api<C: Context>() -> ParentHandler<C> {
host.as_ref().map(|s| &**s).unwrap_or("*"),
external.0
),
target.addr,
target.gateways.iter().join(", "),
target.connect_ssl.is_ok(),
target,
idx == 0
]);
}
@@ -101,11 +94,15 @@ pub struct VHostController {
servers: SyncMutex<BTreeMap<u16, VHostServer<NetworkInterfaceListener>>>,
}
impl VHostController {
pub fn new(db: TypedPatchDb<Database>, interfaces: Arc<NetworkInterfaceController>) -> Self {
pub fn new(
db: TypedPatchDb<Database>,
interfaces: Arc<NetworkInterfaceController>,
crypto_provider: Arc<CryptoProvider>,
) -> Self {
Self {
db,
interfaces,
crypto_provider: Arc::new(tokio_rustls::rustls::crypto::ring::default_provider()),
crypto_provider,
acme_cache: Arc::new(SyncMutex::new(BTreeMap::new())),
servers: SyncMutex::new(BTreeMap::new()),
}
@@ -115,7 +112,7 @@ impl VHostController {
&self,
hostname: Option<InternedString>,
external: u16,
target: impl VHostTarget<NetworkInterfaceListener>,
target: DynVHostTarget<NetworkInterfaceListener>,
) -> Result<Arc<()>, Error> {
self.servers.mutate(|writable| {
let server = if let Some(server) = writable.remove(&external) {
@@ -136,27 +133,26 @@ impl VHostController {
pub fn dump_table(
&self,
) -> BTreeMap<JsonKey<u16>, BTreeMap<JsonKey<Option<InternedString>>, EqSet<ShowTargetInfo>>>
{
) -> BTreeMap<JsonKey<u16>, BTreeMap<JsonKey<Option<InternedString>>, EqSet<String>>> {
let ip_info = self.interfaces.watcher.ip_info();
self.servers.peek(|s| {
s.iter()
.map(|(k, v)| {
(
JsonKey::new(*k),
v.mapping
.borrow()
.iter()
.map(|(k, v)| {
(
JsonKey::new(k.clone()),
v.iter()
.filter(|(_, v)| v.strong_count() > 0)
.map(|(k, _)| ShowTargetInfo::new(k.clone(), &ip_info))
.collect(),
)
})
.collect(),
v.mapping.peek(|m| {
m.iter()
.map(|(k, v)| {
(
JsonKey::new(k.clone()),
v.iter()
.filter(|(_, v)| v.strong_count() > 0)
.map(|(k, _)| format!("{k:?}"))
.collect(),
)
})
.collect()
}),
)
})
.collect()
@@ -179,8 +175,11 @@ impl VHostController {
pub trait VHostTarget<A: Accept>: std::fmt::Debug + Eq {
type PreprocessRes: Send + 'static;
#[allow(unused_variables)]
fn skip(&self, metadata: &<A as Accept>::Metadata) -> bool {
false
fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool {
true
}
fn acme(&self) -> Option<&AcmeProvider> {
None
}
fn preprocess<'a>(
&'a self,
@@ -192,7 +191,8 @@ pub trait VHostTarget<A: Accept>: std::fmt::Debug + Eq {
}
pub trait DynVHostTargetT<A: Accept>: std::fmt::Debug + Any {
fn skip(&self, metadata: &<A as Accept>::Metadata) -> bool;
fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool;
fn acme(&self) -> Option<&AcmeProvider>;
fn preprocess<'a>(
&'a self,
prev: ServerConfig,
@@ -203,8 +203,11 @@ pub trait DynVHostTargetT<A: Accept>: std::fmt::Debug + Any {
fn eq(&self, other: &dyn DynVHostTargetT<A>) -> bool;
}
impl<A: Accept, T: VHostTarget<A> + 'static> DynVHostTargetT<A> for T {
fn skip(&self, metadata: &<A as Accept>::Metadata) -> bool {
VHostTarget::skip(self, metadata)
fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool {
VHostTarget::filter(self, metadata)
}
fn acme(&self) -> Option<&AcmeProvider> {
VHostTarget::acme(self)
}
fn preprocess<'a>(
&'a self,
@@ -226,12 +229,22 @@ impl<A: Accept, T: VHostTarget<A> + 'static> DynVHostTargetT<A> for T {
}
}
struct DynVHostTarget<A: Accept>(Arc<dyn DynVHostTargetT<A> + Send + Sync>);
pub struct DynVHostTarget<A: Accept>(Arc<dyn DynVHostTargetT<A> + Send + Sync>);
impl<A: Accept> DynVHostTarget<A> {
pub fn new<T: VHostTarget<A> + Send + Sync + 'static>(target: T) -> Self {
Self(Arc::new(target))
}
}
impl<A: Accept> Clone for DynVHostTarget<A> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<A: Accept> std::fmt::Debug for DynVHostTarget<A> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl<A: Accept + 'static> PartialEq for DynVHostTarget<A> {
fn eq(&self, other: &Self) -> bool {
self.0.eq(&*other.0)
@@ -259,6 +272,7 @@ impl<A: Accept + 'static> Preprocessed<A> {
#[derive(Debug, Clone)]
pub struct ProxyTarget {
pub filter: DynInterfaceFilter,
pub acme: Option<AcmeProvider>,
pub addr: SocketAddr,
pub connect_ssl: Result<Arc<ClientConfig>, AlpnInfo>, // Ok: yes, connect using ssl, pass through alpn; Err: connect tcp, use provided strategy for alpn
}
@@ -266,7 +280,8 @@ impl PartialEq for ProxyTarget {
fn eq(&self, other: &Self) -> bool {
self.filter == other.filter
&& self.addr == other.addr
&& self.connect_ssl.as_ref().err() == other.connect_ssl.as_ref().err()
&& self.connect_ssl.as_ref().map(Arc::as_ptr)
== other.connect_ssl.as_ref().map(Arc::as_ptr)
}
}
impl Eq for ProxyTarget {}
@@ -274,11 +289,16 @@ impl Eq for ProxyTarget {}
impl<A> VHostTarget<A> for ProxyTarget
where
A: Accept + 'static,
<A as Accept>::Metadata: Send + Sync,
<A as Accept>::Metadata: Visit<ExtractVisitor<GatewayInfo>> + Clone + Send + Sync,
{
type PreprocessRes = AcceptStream;
fn skip(&self, metadata: &<A as Accept>::Metadata) -> bool {
let info = extract::<GatewayInfo,_>(metadata)
fn filter(&self, metadata: &<A as Accept>::Metadata) -> bool {
let info = extract::<GatewayInfo, _>(metadata);
info.as_ref()
.map_or(true, |i| self.filter.filter(&i.id, &i.info))
}
fn acme(&self) -> Option<&AcmeProvider> {
self.acme.as_ref()
}
async fn preprocess<'a>(
&'a self,
@@ -345,15 +365,61 @@ impl Default for AlpnInfo {
type Mapping<A: Accept> = BTreeMap<Option<InternedString>, InOMap<DynVHostTarget<A>, Weak<()>>>;
pub struct VHostConnector<'a, A: Accept + 'static>(&'a Watch<Mapping<A>>, Option<Preprocessed<A>>);
pub struct GetVHostAcmeProvider<A: Accept + 'static>(pub Watch<Mapping<A>>);
impl<A: Accept + 'static> Clone for GetVHostAcmeProvider<A> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<A: Accept + 'static> GetAcmeProvider for GetVHostAcmeProvider<A> {
async fn get_provider<'a, 'b: 'a>(
&'b self,
san_info: &'a BTreeSet<InternedString>,
) -> Option<impl AsRef<AcmeProvider> + Send + 'b> {
self.0.peek(|m| -> Option<AcmeProvider> {
san_info
.iter()
.fold(Some::<Option<&AcmeProvider>>(None), |acc, x| {
let acc = acc?;
if x.parse::<IpAddr>().is_ok() {
return Some(acc);
}
let (t, _) = m
.get(&Some(x.clone()))?
.iter()
.find(|(_, rc)| rc.strong_count() > 0)?;
let acme = t.0.acme()?;
Some(if let Some(acc) = acc {
if acme == acc {
// all must match
Some(acme)
} else {
None
}
} else {
Some(acme)
})
})
.flatten()
.cloned()
})
}
}
impl<'b, A> WrapTlsHandler<A> for &'b mut VHostConnector<'b, A>
pub struct VHostConnector<A: Accept + 'static>(Watch<Mapping<A>>, Option<Preprocessed<A>>);
impl<A: Accept + 'static> Clone for VHostConnector<A> {
fn clone(&self) -> Self {
Self(self.0.clone(), None)
}
}
impl<A> WrapTlsHandler<A> for VHostConnector<A>
where
A: Accept + 'static,
<A as Accept>::Metadata: Send + Sync,
<A as Accept>::Metadata: Visit<ExtractVisitor<GatewayInfo>> + Send + Sync,
{
async fn wrap<'a>(
self,
&'a mut self,
prev: ServerConfig,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
@@ -375,7 +441,7 @@ where
.into_iter()
.flatten()
.filter(|(_, rc)| rc.strong_count() > 0)
.find(|(t, _)| !t.skip(metadata))
.find(|(t, _)| t.0.filter(metadata))
.map(|(e, _)| e.clone())
})?;
@@ -387,6 +453,97 @@ where
}
}
struct VHostListener<M, A>(
TlsListener<
A,
TlsHandlerWrapper<
ChainedHandler<Arc<AcmeTlsHandler<M, GetVHostAcmeProvider<A>>>, RootCaTlsHandler<M>>,
VHostConnector<A>,
>,
>,
)
where
for<'a> M: HasModel<Model = Model<M>>
+ DbAccessMut<CertStore>
+ DbAccessMut<AcmeCertStore>
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ Send
+ Sync,
A: Accept + 'static,
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
+ Visit<ExtractVisitor<GatewayInfo>>
+ Clone
+ Send
+ Sync
+ 'static;
struct VHostListenerMetadata<A: Accept> {
inner: TlsMetadata<A::Metadata>,
preprocessed: Preprocessed<A>,
}
impl<M, A> Accept for VHostListener<M, A>
where
for<'a> M: HasModel<Model = Model<M>>
+ DbAccessMut<CertStore>
+ DbAccessMut<AcmeCertStore>
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ Send
+ Sync
+ 'static,
A: Accept + 'static,
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
+ Visit<ExtractVisitor<GatewayInfo>>
+ Clone
+ Send
+ Sync
+ 'static,
{
type Metadata = VHostListenerMetadata<A>;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
let (metadata, stream) = ready!(self.0.poll_accept(cx)?);
let preprocessed = self.0.tls_handler.wrapper.1.take();
Poll::Ready(Ok((
VHostListenerMetadata {
inner: metadata,
preprocessed: preprocessed.ok_or_else(|| {
Error::new(
eyre!("tlslistener yielded but preprocessed isn't set"),
ErrorKind::Incoherent,
)
})?,
},
stream,
)))
}
}
impl<M, A> VHostListener<M, A>
where
for<'a> M: HasModel<Model = Model<M>>
+ DbAccessMut<CertStore>
+ DbAccessMut<AcmeCertStore>
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ Send
+ Sync
+ 'static,
A: Accept + 'static,
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
+ Visit<ExtractVisitor<GatewayInfo>>
+ Clone
+ Send
+ Sync
+ 'static,
{
async fn handle_next(&mut self) -> Result<(), Error> {
let (metadata, stream) = futures::future::poll_fn(|cx| self.poll_accept(cx)).await?;
metadata.preprocessed.finish(stream);
Ok(())
}
}
struct VHostServer<A: Accept + 'static> {
mapping: Watch<Mapping<A>>,
_thread: NonDetachingJoinHandle<()>,
@@ -407,472 +564,69 @@ impl<'a> From<&'a BTreeMap<Option<InternedString>, BTreeMap<ProxyTarget, Weak<()
}
}
impl VHostServer {
async fn accept(
listener: &mut NetworkInterfaceListener,
mut mapping: watch::Receiver<Mapping>,
db: TypedPatchDb<Database>,
acme_tls_alpn_cache: AcmeTlsAlpnCache,
crypto_provider: Arc<CryptoProvider>,
) -> Result<(), Error> {
let accepted;
loop {
let any_filter = AnyFilter::from(&*mapping.borrow());
let changed_filter = mapping
.wait_for(|m| any_filter != AnyFilter::from(m))
.boxed();
tokio::select! {
a = listener.accept(&any_filter) => {
accepted = a?;
break;
}
_ = changed_filter => {
tracing::debug!("port {} filter changed", listener.port());
}
}
}
let check = listener.check_filter();
tokio::spawn(async move {
let bind = accepted.bind;
if let Err(e) = Self::handle_stream(
accepted,
check,
mapping,
db,
acme_tls_alpn_cache,
crypto_provider,
)
.await
{
tracing::error!("Error in VHostController on {bind}: {e}");
tracing::debug!("{e:?}")
}
});
Ok(())
}
async fn handle_stream(
Accepted {
stream,
wan_ip,
bind,
..
}: Accepted,
check_filter: impl FnOnce(SocketAddr, &DynInterfaceFilter) -> bool,
mapping: watch::Receiver<Mapping>,
db: TypedPatchDb<Database>,
acme_tls_alpn_cache: AcmeTlsAlpnCache,
crypto_provider: Arc<CryptoProvider>,
) -> Result<(), Error> {
let mut stream = BackTrackingIO::new(stream);
let mid: tokio_rustls::StartHandshake<&mut BackTrackingIO<TcpStream>> =
match LazyConfigAcceptor::new(Acceptor::default(), &mut stream).await {
Ok(a) => a,
Err(e) => {
let (_, buf) = stream.rewind();
if std::str::from_utf8(buf)
.ok()
.and_then(|buf| {
buf.lines()
.map(|l| l.trim())
.filter(|l| !l.is_empty())
.next()
})
.map_or(false, |buf| {
regex::Regex::new("[A-Z]+ (.+) HTTP/1")
.unwrap()
.is_match(buf)
})
{
return hyper_util::server::conn::auto::Builder::new(
hyper_util::rt::TokioExecutor::new(),
)
.serve_connection(
hyper_util::rt::TokioIo::new(stream),
hyper_util::service::TowerToHyperService::new(
axum::Router::new().fallback(axum::routing::method_routing::any(
move |req: Request| async move {
match async move {
let host = req
.headers()
.get(http::header::HOST)
.and_then(|host| host.to_str().ok());
if let Some(host) = host {
let uri = Uri::from_parts({
let mut parts =
req.uri().to_owned().into_parts();
parts.scheme = Some("https".parse()?);
parts.authority = Some(host.parse()?);
parts
})?;
Response::builder()
.status(http::StatusCode::TEMPORARY_REDIRECT)
.header(http::header::LOCATION, uri.to_string())
.body(Body::default())
} else {
Response::builder()
.status(http::StatusCode::BAD_REQUEST)
.body(Body::from("Host header required"))
}
}
.await
{
Ok(a) => a,
Err(e) => {
tracing::warn!(
"Error redirecting http request on ssl port: {e}"
);
tracing::error!("{e:?}");
server_error(Error::new(e, ErrorKind::Network))
}
}
},
)),
),
)
.await
.map_err(|e| {
Error::new(color_eyre::eyre::Report::msg(e), ErrorKind::Network)
});
} else {
return Err(e).with_kind(ErrorKind::Network);
}
}
};
let target_name: Option<InternedString> =
mid.client_hello().server_name().map(|s| s.into());
if let Some(domain) = target_name.as_ref() {
if mid
.client_hello()
.alpn()
.into_iter()
.flatten()
.any(|alpn| alpn == ACME_TLS_ALPN_NAME)
{
let cert = WatchStream::new(
acme_tls_alpn_cache
.peek(|c| c.get(&**domain).cloned())
.ok_or_else(|| {
Error::new(
eyre!("No challenge recv available for {domain}"),
ErrorKind::OpenSsl,
)
})?,
);
tracing::info!("Waiting for verification cert for {domain}");
let cert = cert
.filter(|c| c.is_some())
.next()
.await
.flatten()
.ok_or_else(|| {
Error::new(
eyre!("No challenge available for {domain}"),
ErrorKind::OpenSsl,
)
})?;
tracing::info!("Verification cert received for {domain}");
let mut cfg = ServerConfig::builder_with_provider(crypto_provider.clone())
.with_safe_default_protocol_versions()
.with_kind(crate::ErrorKind::OpenSsl)?
.with_no_client_auth()
.with_cert_resolver(Arc::new(SingleCertResolver(cert)));
cfg.alpn_protocols = vec![ACME_TLS_ALPN_NAME.to_vec()];
tracing::info!("performing ACME auth challenge");
let mut accept = mid.into_stream(Arc::new(cfg));
let io = accept.get_mut().unwrap();
let buffered = io.stop_buffering();
io.write_all(&buffered).await?;
accept.await?;
tracing::info!("ACME auth challenge completed");
return Ok(());
}
}
let target = {
let m = mapping.borrow();
m.get(&target_name)
.into_iter()
.flatten()
.find(|(_, rc)| rc.strong_count() > 0)
.or_else(|| {
if target_name
.as_ref()
.map(|s| s.parse::<IpAddr>().is_ok())
.unwrap_or(true)
{
m.get(&None)
.into_iter()
.flatten()
.find(|(_, rc)| rc.strong_count() > 0)
} else {
None
}
})
.map(|(target, _)| target.clone())
};
if let Some(target) = target {
if !check_filter(bind, &target.filter) {
log::warn!("Connection from {bind} to {target:?} rejected by filter");
return Ok(());
}
let peek = db.peek().await;
let root = peek
.as_private()
.as_key_store()
.as_local_certs()
.as_root_cert()
.de()?;
let mut cfg = async {
if let Some((domain, provider, settings)) =
target_name.as_ref().and_then(|domain| {
target.acme.as_ref().and_then(|a| {
peek.as_public()
.as_server_info()
.as_network()
.as_acme()
.as_idx(a)
.map(|s| (domain, a, s))
})
})
{
let acme_settings = settings.de()?;
let mut identifiers = vec![Identifier::Dns(domain.to_string())];
if false
// Requires RFC 8738
{
if let Some(wan_ip) = wan_ip {
identifiers.push(Identifier::Ip(wan_ip.into()));
}
}
let (send, recv) = watch::channel(None);
acme_tls_alpn_cache.mutate(|c| c.insert(domain.clone(), recv));
let cert = async_acme::rustls_helper::order(
|_, cert| {
send.send_replace(Some(Arc::new(cert)));
Ok(())
},
provider.0.as_str(),
&identifiers,
Some(&AcmeCertCache(&db)),
&acme_settings.contact,
)
.await
.with_kind(ErrorKind::OpenSsl)?;
return Ok(ServerConfig::builder_with_provider(crypto_provider.clone())
.with_safe_default_protocol_versions()
.with_kind(crate::ErrorKind::OpenSsl)?
.with_no_client_auth()
.with_cert_resolver(Arc::new(SingleCertResolver(Arc::new(cert)))));
}
let hostnames = target_name
.into_iter()
.chain([InternedString::from_display(&bind.ip())])
.chain(wan_ip.as_ref().map(InternedString::from_display))
.collect();
let key = db
.mutate(|v| {
v.as_private_mut()
.as_key_store_mut()
.as_local_certs_mut()
.cert_for(&hostnames)
})
.await
.result?;
let cfg = ServerConfig::builder_with_provider(crypto_provider.clone())
.with_safe_default_protocol_versions()
.with_kind(crate::ErrorKind::OpenSsl)?
.with_no_client_auth();
if mid
.client_hello()
.signature_schemes()
.contains(&tokio_rustls::rustls::SignatureScheme::ED25519)
{
cfg.with_single_cert(
key.fullchain_ed25519()
.into_iter()
.map(|c| {
Ok(tokio_rustls::rustls::pki_types::CertificateDer::from(
c.to_der()?,
))
})
.collect::<Result<_, Error>>()?,
PrivateKeyDer::from(PrivatePkcs8KeyDer::from(
key.leaf.keys.ed25519.private_key_to_pkcs8()?,
)),
)
} else {
cfg.with_single_cert(
key.fullchain_nistp256()
.into_iter()
.map(|c| {
Ok(tokio_rustls::rustls::pki_types::CertificateDer::from(
c.to_der()?,
))
})
.collect::<Result<_, Error>>()?,
PrivateKeyDer::from(PrivatePkcs8KeyDer::from(
key.leaf.keys.nistp256.private_key_to_pkcs8()?,
)),
)
}
.with_kind(crate::ErrorKind::OpenSsl)
}
.await?;
let mut tcp_stream = TcpStream::connect(target.addr).await?;
match target.connect_ssl {
Ok(()) => {
let mut client_cfg =
tokio_rustls::rustls::ClientConfig::builder_with_provider(crypto_provider)
.with_safe_default_protocol_versions()
.with_kind(crate::ErrorKind::OpenSsl)?
.with_root_certificates({
let mut store = RootCertStore::empty();
store
.add(CertificateDer::from(root.to_der()?))
.with_kind(crate::ErrorKind::OpenSsl)?;
store
})
.with_no_client_auth();
client_cfg.alpn_protocols = mid
.client_hello()
.alpn()
.into_iter()
.flatten()
.map(|x| x.to_vec())
.collect();
let mut target_stream = TlsConnector::from(Arc::new(client_cfg))
.connect_with(
ServerName::IpAddress(target.addr.ip().into()),
tcp_stream,
|conn| {
cfg.alpn_protocols
.extend(conn.alpn_protocol().into_iter().map(|p| p.to_vec()))
},
)
.await
.with_kind(crate::ErrorKind::OpenSsl)?;
let mut accept = mid.into_stream(Arc::new(cfg));
let io = accept.get_mut().unwrap();
let buffered = io.stop_buffering();
io.write_all(&buffered).await?;
let mut tls_stream = match accept.await {
Ok(a) => a,
Err(e) => {
tracing::trace!(
"VHostController: failed to accept TLS connection on {bind}: {e}"
);
tracing::trace!("{e:?}");
return Ok(());
}
};
tokio::io::copy_bidirectional(&mut tls_stream, &mut target_stream).await
}
Err(AlpnInfo::Reflect) => {
for proto in mid.client_hello().alpn().into_iter().flatten() {
cfg.alpn_protocols.push(proto.into());
}
let mut accept = mid.into_stream(Arc::new(cfg));
let io = accept.get_mut().unwrap();
let buffered = io.stop_buffering();
io.write_all(&buffered).await?;
let mut tls_stream = match accept.await {
Ok(a) => a,
Err(e) => {
tracing::trace!(
"VHostController: failed to accept TLS connection on {bind}: {e}"
);
tracing::trace!("{e:?}");
return Ok(());
}
};
tokio::io::copy_bidirectional(&mut tls_stream, &mut tcp_stream).await
}
Err(AlpnInfo::Specified(alpn)) => {
cfg.alpn_protocols = alpn.into_iter().map(|a| a.0).collect();
let mut accept = mid.into_stream(Arc::new(cfg));
let io = accept.get_mut().unwrap();
let buffered = io.stop_buffering();
io.write_all(&buffered).await?;
let mut tls_stream = match accept.await {
Ok(a) => a,
Err(e) => {
tracing::trace!(
"VHostController: failed to accept TLS connection on {bind}: {e}"
);
tracing::trace!("{e:?}");
return Ok(());
}
};
tokio::io::copy_bidirectional(&mut tls_stream, &mut tcp_stream).await
}
}
.map_or_else(
|e| {
use std::io::ErrorKind as E;
match e.kind() {
E::UnexpectedEof
| E::BrokenPipe
| E::ConnectionAborted
| E::ConnectionReset
| E::ConnectionRefused
| E::TimedOut
| E::Interrupted
| E::NotConnected => Ok(()),
_ => Err(e),
}
},
|_| Ok(()),
)?;
} else {
// 503
}
Ok::<_, Error>(())
}
impl<A: Accept> VHostServer<A> {
#[instrument(skip_all)]
fn new<A: Accept>(
fn new<M: HasModel>(
listener: A,
db: TypedPatchDb<Database>,
db: TypedPatchDb<M>,
crypto_provider: Arc<CryptoProvider>,
acme_cache: AcmeTlsAlpnCache,
) -> Result<Self, Error> {
) -> Self
where
for<'a> M: HasModel<Model = Model<M>>
+ DbAccessMut<CertStore>
+ DbAccessMut<AcmeCertStore>
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ Send
+ Sync
+ 'static,
A: Accept + Send + 'static,
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
+ Visit<ExtractVisitor<GatewayInfo>>
+ Clone
+ Send
+ Sync
+ 'static,
{
let mapping = Watch::new(BTreeMap::new());
Ok(Self {
Self {
mapping: mapping.clone(),
_thread: tokio::spawn(async move {
let listener = TlsListener::new(
let mut listener = VHostListener(TlsListener::new(
listener,
VHostTlsHandler {
cert_handler: ChainedHandler(
&AcmeTlsHandler {
db: &db,
acme_cache: &acme_cache,
crypto_provider: &crypto_provider,
get_provider: todo!(),
TlsHandlerWrapper {
inner: ChainedHandler(
Arc::new(AcmeTlsHandler {
db: db.clone(),
acme_cache,
crypto_provider: crypto_provider.clone(),
get_provider: GetVHostAcmeProvider(mapping.clone()),
in_progress: Watch::new(BTreeSet::new()),
}),
RootCaTlsHandler {
db,
crypto_provider,
},
todo!(),
),
alpn_handler: todo!(),
wrapper: VHostConnector(mapping, None),
},
);
));
loop {
if let Err(e) = Self::accept(&mut listener, &mapping).await {
tracing::error!("VHostController: failed to accept connection: {e}");
if let Err(e) = listener.handle_next().await {
tracing::error!("VHostServer: failed to accept connection: {e}");
tracing::debug!("{e:?}");
}
}
})
.into(),
})
}
}
fn add(&self, hostname: Option<InternedString>, target: ProxyTarget) -> Result<Arc<()>, Error> {
fn add(
&self,
hostname: Option<InternedString>,
target: DynVHostTarget<A>,
) -> Result<Arc<()>, Error> {
let target = target.into();
let mut res = Ok(Arc::new(()));
self.mapping.send_if_modified(|writable| {
let mut changed = false;
@@ -888,7 +642,7 @@ impl VHostServer {
res = Ok(rc);
changed
});
if !self.mapping.is_closed() {
if self.mapping.watcher_count() > 1 {
res
} else {
Err(Error::new(
@@ -913,6 +667,6 @@ impl VHostServer {
});
}
fn is_empty(&self) -> bool {
self.mapping.borrow().is_empty()
self.mapping.peek(|m| m.is_empty())
}
}

View File

@@ -1,21 +1,22 @@
use std::any::Any;
use std::collections::BTreeMap;
use std::future::Future;
use std::net::SocketAddr;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use std::task::{ready, Poll};
use std::time::Duration;
use axum::Router;
use futures::future::Either;
use futures::FutureExt;
use futures::{FutureExt, TryFutureExt};
use helpers::NonDetachingJoinHandle;
use http::Extensions;
use hyper_util::rt::{TokioIo, TokioTimer};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use visit_rs::{Visit, Visitor};
use visit_rs::{Visit, VisitFields, Visitor};
use crate::net::static_server::{ui_router, UiContext};
use crate::prelude::*;
@@ -38,6 +39,16 @@ impl<'a> MetadataVisitor for ExtensionVisitor<'a> {
self.0.insert(metadata.clone());
}
}
impl<'a> Visit<ExtensionVisitor<'a>>
for Box<dyn for<'x> Visit<ExtensionVisitor<'x>> + Send + Sync + 'static>
{
fn visit(
&self,
visitor: &mut ExtensionVisitor<'a>,
) -> <ExtensionVisitor<'a> as Visitor>::Result {
(&**self).visit(visitor)
}
}
pub struct ExtractVisitor<T>(Option<T>);
impl<T> Visitor for ExtractVisitor<T> {
@@ -50,7 +61,12 @@ impl<T: Clone + Send + Sync + 'static> MetadataVisitor for ExtractVisitor<T> {
}
}
}
pub fn extract<T, M: Visit<ExtractVisitor<T>>>(metadata: &M) -> Option<T> {
pub fn extract<
T: Clone + Send + Sync + 'static,
M: Visit<ExtractVisitor<T>> + Clone + Send + Sync + 'static,
>(
metadata: &M,
) -> Option<T> {
let mut visitor = ExtractVisitor(None);
visitor.visit(metadata);
visitor.0
@@ -73,6 +89,13 @@ pub trait Accept {
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>>;
fn into_dyn(self) -> DynAccept
where
Self: Sized + Send + Sync + 'static,
for<'a> Self::Metadata: Visit<ExtensionVisitor<'a>> + Send + Sync + 'static,
{
DynAccept::new(self)
}
}
impl Accept for TcpListener {
@@ -121,6 +144,47 @@ where
}
}
#[derive(Clone, VisitFields)]
pub struct MapListenerMetadata<K, M> {
pub inner: M,
pub key: K,
}
impl<K, M, V> Visit<V> for MapListenerMetadata<K, M>
where
V: MetadataVisitor,
K: Visit<V>,
M: Visit<V>,
{
fn visit(&self, visitor: &mut V) -> <V as Visitor>::Result {
self.visit_fields(visitor).collect()
}
}
impl<K, A> Accept for BTreeMap<K, A>
where
K: Clone,
A: Accept,
{
type Metadata = MapListenerMetadata<K, A::Metadata>;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
for (key, listener) in self {
if let Poll::Ready((metadata, stream)) = listener.poll_accept(cx)? {
return Poll::Ready(Ok((
MapListenerMetadata {
inner: metadata,
key: key.clone(),
},
stream,
)));
}
}
Poll::Pending
}
}
impl<A, B> Accept for Either<A, B>
where
A: Accept,
@@ -150,6 +214,68 @@ impl<A: Accept> Accept for Option<A> {
}
}
trait DynAcceptT: Send + Sync {
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<
Result<
(
Box<dyn for<'a> Visit<ExtensionVisitor<'a>> + Send + Sync>,
AcceptStream,
),
Error,
>,
>;
}
impl<A> DynAcceptT for A
where
A: Accept + Send + Sync,
for<'a> <A as Accept>::Metadata: Visit<ExtensionVisitor<'a>> + Send + Sync + 'static,
{
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<
Result<
(
Box<dyn for<'a> Visit<ExtensionVisitor<'a>> + Send + Sync>,
AcceptStream,
),
Error,
>,
> {
let (metadata, stream) = ready!(Accept::poll_accept(self, cx)?);
Poll::Ready(Ok((Box::new(metadata), stream)))
}
}
pub struct DynAccept(Box<dyn DynAcceptT>);
impl Accept for DynAccept {
type Metadata = Box<dyn for<'a> Visit<ExtensionVisitor<'a>> + Send + Sync>;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
DynAcceptT::poll_accept(&mut *self.0, cx)
}
fn into_dyn(self) -> DynAccept
where
Self: Sized,
for<'a> Self::Metadata: Visit<ExtensionVisitor<'a>> + Send + Sync + 'static,
{
self
}
}
impl DynAccept {
pub fn new<A>(accept: A) -> Self
where
A: Accept + Send + Sync + 'static,
for<'a> <A as Accept>::Metadata: Visit<ExtensionVisitor<'a>> + Send + Sync + 'static,
{
Self(Box::new(accept))
}
}
#[pin_project::pin_project]
pub struct Acceptor<A: Accept> {
acceptor: Watch<A>,
@@ -184,6 +310,64 @@ impl Acceptor<Vec<TcpListener>> {
))
}
}
impl Acceptor<Vec<DynAccept>> {
pub async fn bind_dyn(listen: impl IntoIterator<Item = SocketAddr>) -> Result<Self, Error> {
Ok(Self::new(
futures::future::try_join_all(
listen
.into_iter()
.map(TcpListener::bind)
.map(|f| f.map_ok(DynAccept::new)),
)
.await?,
))
}
}
impl<K> Acceptor<BTreeMap<K, TcpListener>>
where
K: Ord + Clone + Send + Sync + 'static,
{
pub async fn bind_map(
listen: impl IntoIterator<Item = (K, SocketAddr)>,
) -> Result<Self, Error> {
Ok(Self::new(
futures::future::try_join_all(listen.into_iter().map(|(key, addr)| async move {
Ok::<_, Error>((
key,
TcpListener::bind(addr)
.await
.with_kind(ErrorKind::Network)?,
))
}))
.await?
.into_iter()
.collect(),
))
}
}
impl<K> Acceptor<BTreeMap<K, DynAccept>>
where
K: Ord + Clone + Send + Sync + 'static,
{
pub async fn bind_map_dyn(
listen: impl IntoIterator<Item = (K, SocketAddr)>,
) -> Result<Self, Error> {
Ok(Self::new(
futures::future::try_join_all(listen.into_iter().map(|(key, addr)| async move {
Ok::<_, Error>((
key,
TcpListener::bind(addr)
.await
.with_kind(ErrorKind::Network)?,
))
}))
.await?
.into_iter()
.map(|(key, listener)| (key, listener.into_dyn()))
.collect(),
))
}
}
pub struct WebServerAcceptorSetter<A: Accept> {
acceptor: Watch<A>,

View File

@@ -3,7 +3,7 @@ use std::collections::{BTreeMap, BTreeSet};
use axum::Router;
use futures::future::ready;
use models::DataUrl;
use rpc_toolkit::{Context, HandlerExt, ParentHandler, Server, from_fn_async};
use rpc_toolkit::{from_fn_async, Context, HandlerExt, ParentHandler, Server};
use serde::{Deserialize, Serialize};
use ts_rs::TS;
@@ -141,9 +141,3 @@ pub fn registry_router(ctx: RegistryContext) -> Router {
}),
)
}
impl<A: Accept + Send + Sync + 'static> WebServer<A> {
pub fn serve_registry(&mut self, ctx: RegistryContext) {
self.serve_router(registry_router(ctx))
}
}

View File

@@ -6,11 +6,11 @@ use ipnet::IpNet;
use itertools::Itertools;
use openssl::pkey::{PKey, Private};
use crate::HOST_IP;
use crate::service::effects::callbacks::CallbackHandler;
use crate::service::effects::prelude::*;
use crate::service::rpc::CallbackId;
use crate::util::serde::Pem;
use crate::HOST_IP;
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, TS, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
@@ -95,7 +95,7 @@ pub async fn get_ssl_certificate(
.as_entries()?
.into_iter()
.flat_map(|(_, net)| net.as_ip_info().transpose_ref())
.flat_map(|net| net.as_subnets().de().log_err())
.flat_map(|net| net.as_deref().as_subnets().de().log_err())
.flatten()
.any(|s| s.addr() == ip)
{

View File

@@ -1,12 +1,11 @@
use std::collections::{BTreeMap, BTreeSet};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::collections::BTreeMap;
use std::net::{IpAddr, SocketAddr};
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use clap::Parser;
use cookie::{Cookie, Expiration, SameSite};
use helpers::NonDetachingJoinHandle;
use http::HeaderMap;
use imbl::OrdMap;
use imbl_value::InternedString;
@@ -23,14 +22,16 @@ use url::Url;
use crate::auth::Sessions;
use crate::context::config::ContextConfig;
use crate::context::{CliContext, RpcContext};
use crate::db::model::public::{NetworkInterfaceInfo, NetworkInterfaceType};
use crate::db::model::public::NetworkInterfaceInfo;
use crate::middleware::auth::AuthContext;
use crate::net::forward::PortForwardController;
use crate::net::gateway::{IdFilter, InterfaceFilter, NetworkInterfaceWatcher};
use crate::net::gateway::{IdFilter, InterfaceFilter};
use crate::prelude::*;
use crate::rpc_continuations::{OpenAuthedContinuations, RpcContinuations};
use crate::tunnel::db::{GatewayPort, TunnelDatabase};
use crate::tunnel::wg::WIREGUARD_INTERFACE_NAME;
use crate::tunnel::TUNNEL_DEFAULT_LISTEN;
use crate::util::collections::OrdMapIterMut;
use crate::util::io::read_file_to_string;
use crate::util::sync::{SyncMutex, Watch};
use crate::util::Invoke;
@@ -100,7 +101,24 @@ impl TunnelContext {
)
.await?;
let listen = config.tunnel_listen.unwrap_or(TUNNEL_DEFAULT_LISTEN);
let net_iface = Watch::new(crate::net::utils::load_network_interface_info().await?);
let ip_info = crate::net::utils::load_ip_info().await?;
let net_iface = db
.mutate(|db| {
db.as_gateways_mut().mutate(|g| {
for (_, v) in OrdMapIterMut::from(&mut *g) {
v.ip_info = None;
}
for (id, info) in ip_info {
if id.as_str() != WIREGUARD_INTERFACE_NAME {
g.entry(id).or_default().ip_info = Some(Arc::new(info));
}
}
Ok(g.clone())
})
})
.await
.result?;
let net_iface = Watch::new(net_iface);
let forward = PortForwardController::new(net_iface.clone_unseen());
Command::new("sysctl")
@@ -111,12 +129,8 @@ impl TunnelContext {
for iface in net_iface.peek(|i| {
i.iter()
.filter(|(_, info)| {
dbg!(info).ip_info.as_ref().map_or(false, |i| {
dbg!(i).device_type != Some(NetworkInterfaceType::Wireguard)
})
})
.map(|(name, _)| name)
.filter(|id| id.as_str() != WIREGUARD_INTERFACE_NAME)
.cloned()
.collect::<Vec<_>>()
}) {

View File

@@ -1,15 +1,13 @@
use std::collections::{BTreeMap, BTreeSet};
use std::net::{IpAddr, SocketAddr, SocketAddrV4};
use std::collections::BTreeMap;
use std::net::{SocketAddr, SocketAddrV4};
use std::path::PathBuf;
use clap::builder::ValueParserFactory;
use clap::Parser;
use imbl::HashMap;
use imbl::{HashMap, OrdMap};
use imbl_value::InternedString;
use itertools::Itertools;
use models::{FromStrParser, GatewayId};
use openssl::pkey::{PKey, Private};
use openssl::x509::X509;
use patch_db::json_ptr::{JsonPointer, ROOT};
use patch_db::Dump;
use rpc_toolkit::yajrc::RpcError;
@@ -20,16 +18,14 @@ use ts_rs::TS;
use crate::auth::Sessions;
use crate::context::CliContext;
use crate::net::ssl::FullchainCertData;
use crate::db::model::public::NetworkInterfaceInfo;
use crate::prelude::*;
use crate::sign::AnyVerifyingKey;
use crate::tunnel::auth::SignerInfo;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::web::TunnelCertData;
use crate::tunnel::wg::WgServer;
use crate::util::serde::{
apply_expr, deserialize_from_str, serialize_display, HandlerExtSerde, Pem,
};
use crate::util::serde::{apply_expr, deserialize_from_str, serialize_display, HandlerExtSerde};
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct GatewayPort(pub GatewayId, pub u16);
@@ -84,7 +80,8 @@ pub struct TunnelDatabase {
pub sessions: Sessions,
pub password: Option<String>,
pub auth_pubkeys: HashMap<AnyVerifyingKey, SignerInfo>,
pub certificates: BTreeMap<JsonKey<BTreeSet<InternedString>>, TunnelCertData>,
pub certificates: Option<TunnelCertData>,
pub gateways: OrdMap<GatewayId, NetworkInterfaceInfo>,
pub wg: WgServer,
pub port_forwards: PortForwards,
}

View File

@@ -79,9 +79,3 @@ pub fn tunnel_router(ctx: TunnelContext) -> Router {
}),
)
}
impl<A: Accept + Send + Sync + 'static> WebServer<A> {
pub fn serve_tunnel(&mut self, ctx: TunnelContext) {
self.serve_router(tunnel_router(ctx))
}
}

View File

@@ -10,11 +10,16 @@ use openssl::x509::X509;
use rpc_toolkit::{from_fn_async, Context, Empty, HandlerArgs, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio_rustls::rustls::server::ClientHello;
use tokio_rustls::rustls::ServerConfig;
use crate::context::CliContext;
use crate::net::ssl::SANInfo;
use crate::net::tls::TlsHandler;
use crate::net::web_server::Accept;
use crate::prelude::*;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::db::TunnelDatabase;
use crate::util::serde::Pem;
#[derive(Debug, Deserialize, Serialize, Parser)]
@@ -26,6 +31,19 @@ pub struct TunnelCertData {
pub cert: Pem<Vec<X509>>,
}
pub struct TunnelCertHandler(TypedPatchDb<TunnelDatabase>);
impl<'a, A> TlsHandler<'a, A> for TunnelCertHandler
where
A: Accept,
{
async fn get_config(
&'a mut self,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig> {
}
}
pub fn web_api<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
.subcommand("init", from_fn_async(init_web_rpc).no_cli())

View File

@@ -13,6 +13,8 @@ use crate::util::io::write_file_atomic;
use crate::util::serde::Base64;
use crate::util::Invoke;
pub const WIREGUARD_INTERFACE_NAME: &str = "wg-start-tunnel";
#[derive(Deserialize, Serialize, HasModel)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
@@ -37,7 +39,7 @@ impl WgServer {
pub async fn sync(&self) -> Result<(), Error> {
Command::new("wg-quick")
.arg("down")
.arg("wg0")
.arg(WIREGUARD_INTERFACE_NAME)
.invoke(ErrorKind::Network)
.await
.or_else(|e| {
@@ -49,13 +51,13 @@ impl WgServer {
}
})?;
write_file_atomic(
"/etc/wireguard/wg0.conf",
const_format::formatcp!("/etc/wireguard/{WIREGUARD_INTERFACE_NAME}.conf"),
self.server_config().to_string().as_bytes(),
)
.await?;
Command::new("wg-quick")
.arg("up")
.arg("wg0")
.arg(WIREGUARD_INTERFACE_NAME)
.invoke(ErrorKind::Network)
.await?;
Ok(())

View File

@@ -266,6 +266,9 @@ impl<T> Watch<T> {
version: 0,
}
}
pub fn watcher_count(&self) -> usize {
Arc::strong_count(&self.shared)
}
#[cfg_attr(feature = "unstable", inline(never))]
pub fn poll_changed(&mut self, cx: &mut std::task::Context<'_>) -> Poll<()> {
self.shared.mutate(|shared| {

View File

@@ -75,6 +75,8 @@ impl VersionT for Version {
}
fix_host(&mut db["public"]["serverInfo"]["network"]["host"])?;
db["private"]["keyStore"]["localCerts"] = db["private"]["keyStore"]["local_certs"].clone();
Ok(Value::Null)
}
fn down(self, _db: &mut Value) -> Result<(), Error> {

View File

@@ -1,3 +1,8 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
export type NetworkInterfaceType = "ethernet" | "wireless" | "wireguard"
export type NetworkInterfaceType =
| "ethernet"
| "wireless"
| "bridge"
| "wireguard"
| "loopback"

View File

@@ -1,3 +1,9 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { AnyVerifyingKey } from "./AnyVerifyingKey"
import type { ContactInfo } from "./ContactInfo"
export type SignerInfo = { name: string }
export type SignerInfo = {
name: string
contact: Array<ContactInfo>
keys: Array<AnyVerifyingKey>
}

View File

@@ -80,6 +80,8 @@ export const PRIVATE_IPV4_RANGES = [
new IpNet("192.168.0.0/16"),
]
export const IPV4_LOOPBACK = new IpNet("127.0.0.0/8")
export const IPV6_LOOPBACK = new IpNet("::1/128")
export const IPV6_LINK_LOCAL = new IpNet("fe80::/10")
export const CGNAT = new IpNet("100.64.0.0/10")

View File

@@ -24,6 +24,11 @@ export class GatewayService {
map(gateways =>
Object.entries(gateways)
.filter(([_, val]) => !!val?.ipInfo)
.filter(
([_, val]) =>
val?.ipInfo?.deviceType !== 'bridge' &&
val?.ipInfo?.deviceType !== 'loopback',
)
.map(([id, val]) => {
const subnets =
val.ipInfo?.subnets.map(s => utils.IpNet.parse(s)) ?? []