diff --git a/core/Cargo.lock b/core/Cargo.lock index 6f1100238..76b377fab 100644 --- a/core/Cargo.lock +++ b/core/Cargo.lock @@ -156,6 +156,15 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" +[[package]] +name = "ansi-width" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "219e3ce6f2611d83b51ec2098a12702112c29e57203a6b0a0929b2cddb486608" +dependencies = [ + "unicode-width 0.1.14", +] + [[package]] name = "anstream" version = "0.6.21" @@ -1777,16 +1786,18 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crossterm" -version = "0.27.0" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" +checksum = "d8b9f2e4c67f833b660cdb0a3523065869fb35570177239812ed4c905aeff87b" dependencies = [ "bitflags 2.10.0", "crossterm_winapi", + "derive_more 2.0.1", + "document-features", "futures-core", - "libc", - "mio 0.8.11", + "mio", "parking_lot", + "rustix 1.1.2", "signal-hook", "signal-hook-mio", "winapi", @@ -4535,26 +4546,14 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.11" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" dependencies = [ "libc", "log", "wasi 0.11.1+wasi-snapshot-preview1", - "windows-sys 0.48.0", -] - -[[package]] -name = "mio" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d83b0086dc8ecf3ce9ae2874b2d1290252e2a30720bea58a5c6639b0092873" -dependencies = [ - "libc", - "log", - "wasi 0.11.1+wasi-snapshot-preview1", - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -4726,7 +4725,7 @@ dependencies = [ "kqueue", "libc", "log", - "mio 1.1.0", + "mio", "notify-types", "walkdir", "windows-sys 0.60.2", @@ -6440,18 +6439,17 @@ dependencies = [ [[package]] name = "rustyline-async" -version = "0.4.2" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6eb06391513b2184f0a5405c11a4a0a5302e8be442f4c5c35267187c2b37d5" +checksum = "6e07ddce8399c61495b405dc94d4f30d01fc1c5e1238f10b9c09940678bc81ab" dependencies = [ + "ansi-width", "crossterm", - "futures-channel", "futures-util", "pin-project", "thingbuf", - "thiserror 1.0.69", + "thiserror 2.0.17", "unicode-segmentation", - "unicode-width 0.1.14", ] [[package]] @@ -6907,7 +6905,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd" dependencies = [ "libc", - "mio 0.8.11", + "mio", "signal-hook", ] @@ -7338,7 +7336,7 @@ dependencies = [ "libc", "log", "mbrman", - "mio 1.1.0", + "mio", "models", "new_mime_guess", "nix 0.30.1", @@ -7793,7 +7791,7 @@ checksum = "ff360e02eab121e0bc37a2d3b4d4dc622e6eda3a8e5253d5435ecf5bd4c68408" dependencies = [ "bytes", "libc", - "mio 1.1.0", + "mio", "parking_lot", "pin-project-lite", "signal-hook-registry", diff --git a/core/models/src/errors.rs b/core/models/src/errors.rs index 826da282b..929a9502c 100644 --- a/core/models/src/errors.rs +++ b/core/models/src/errors.rs @@ -187,7 +187,6 @@ impl Display for ErrorKind { } } -#[derive(Debug)] pub struct Error { pub source: color_eyre::eyre::Error, pub debug: Option, @@ -198,7 +197,17 @@ pub struct Error { impl Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}: {}", self.kind.as_str(), self.source) + write!(f, "{}: {:#}", self.kind.as_str(), self.source) + } +} +impl Debug for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}: {:?}", + self.kind.as_str(), + self.debug.as_ref().unwrap_or(&self.source) + ) } } impl Error { diff --git a/core/startos/Cargo.toml b/core/startos/Cargo.toml index a1f641b82..ad4dc76a3 100644 --- a/core/startos/Cargo.toml +++ b/core/startos/Cargo.toml @@ -222,7 +222,7 @@ reqwest_cookie_store = "0.8.0" rpassword = "7.2.0" rpc-toolkit = { git = "https://github.com/Start9Labs/rpc-toolkit.git", branch = "master" } rust-argon2 = "2.0.0" -rustyline-async = "0.4.1" +rustyline-async = "0.4.7" safelog = { version = "0.4.8", git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit", optional = true } semver = { version = "1.0.20", features = ["serde"] } serde = { version = "1.0", features = ["derive", "rc"] } diff --git a/core/startos/src/account.rs b/core/startos/src/account.rs index b86ebbbfb..14752b5da 100644 --- a/core/startos/src/account.rs +++ b/core/startos/src/account.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeMap; use std::time::SystemTime; use imbl_value::InternedString; @@ -120,12 +121,20 @@ impl AccountInfo { key_store.as_onion_mut().insert_key(tor_key)?; } let cert_store = key_store.as_local_certs_mut(); - cert_store - .as_root_key_mut() - .ser(Pem::new_ref(&self.root_ca_key))?; - cert_store - .as_root_cert_mut() - .ser(Pem::new_ref(&self.root_ca_cert))?; + if cert_store.as_root_cert().de()?.0 != self.root_ca_cert { + cert_store + .as_root_key_mut() + .ser(Pem::new_ref(&self.root_ca_key))?; + cert_store + .as_root_cert_mut() + .ser(Pem::new_ref(&self.root_ca_cert))?; + let int_key = crate::net::ssl::generate_key()?; + let int_cert = + crate::net::ssl::make_int_cert((&self.root_ca_key, &self.root_ca_cert), &int_key)?; + cert_store.as_int_key_mut().ser(&Pem(int_key))?; + cert_store.as_int_cert_mut().ser(&Pem(int_cert))?; + cert_store.as_leaves_mut().ser(&BTreeMap::new())?; + } Ok(()) } diff --git a/core/startos/src/bins/registry.rs b/core/startos/src/bins/registry.rs index 2aacebd39..53fd06ea7 100644 --- a/core/startos/src/bins/registry.rs +++ b/core/startos/src/bins/registry.rs @@ -6,19 +6,22 @@ use rpc_toolkit::CliApp; use tokio::signal::unix::signal; use tracing::instrument; -use crate::context::CliContext; use crate::context::config::ClientConfig; +use crate::context::CliContext; use crate::net::web_server::{Acceptor, WebServer}; use crate::prelude::*; use crate::registry::context::{RegistryConfig, RegistryContext}; +use crate::registry::registry_router; use crate::util::logger::LOGGER; #[instrument(skip_all)] async fn inner_main(config: &RegistryConfig) -> Result<(), Error> { let server = async { let ctx = RegistryContext::init(config).await?; - let mut server = WebServer::new(Acceptor::bind([ctx.listen]).await?); - server.serve_registry(ctx.clone()); + let server = WebServer::new( + Acceptor::bind([ctx.listen]).await?, + registry_router(ctx.clone()), + ); let mut shutdown_recv = ctx.shutdown.subscribe(); diff --git a/core/startos/src/bins/start_init.rs b/core/startos/src/bins/start_init.rs index 477207e91..259831315 100644 --- a/core/startos/src/bins/start_init.rs +++ b/core/startos/src/bins/start_init.rs @@ -11,7 +11,8 @@ use crate::disk::main::DEFAULT_PASSWORD; use crate::disk::REPAIR_DISK_PATH; use crate::firmware::{check_for_firmware_update, update_firmware}; use crate::init::{InitPhases, STANDBY_MODE_PATH}; -use crate::net::web_server::{UpgradableListener, WebServer}; +use crate::net::gateway::UpgradableListener; +use crate::net::web_server::WebServer; use crate::prelude::*; use crate::progress::FullProgressTracker; use crate::shutdown::Shutdown; diff --git a/core/startos/src/bins/startd.rs b/core/startos/src/bins/startd.rs index ac0531b62..86c391508 100644 --- a/core/startos/src/bins/startd.rs +++ b/core/startos/src/bins/startd.rs @@ -12,8 +12,9 @@ use tracing::instrument; use crate::context::config::ServerConfig; use crate::context::rpc::InitRpcContextPhases; use crate::context::{DiagnosticContext, InitContext, RpcContext}; -use crate::net::gateway::SelfContainedNetworkInterfaceListener; -use crate::net::web_server::{Acceptor, UpgradableListener, WebServer}; +use crate::net::gateway::{BindTcp, SelfContainedNetworkInterfaceListener, UpgradableListener}; +use crate::net::static_server::refresher; +use crate::net::web_server::{Acceptor, WebServer}; use crate::shutdown::Shutdown; use crate::system::launch_metrics_task; use crate::util::io::append_file; @@ -147,9 +148,10 @@ pub fn main(args: impl IntoIterator) { .build() .expect("failed to initialize runtime"); let res = rt.block_on(async { - let mut server = WebServer::new(Acceptor::bind_upgradable( - SelfContainedNetworkInterfaceListener::bind(80), - )); + let mut server = WebServer::new( + Acceptor::bind_upgradable(SelfContainedNetworkInterfaceListener::bind(BindTcp, 80)), + refresher(), + ); match inner_main(&mut server, &config).await { Ok(a) => { server.shutdown().await; diff --git a/core/startos/src/bins/tunnel.rs b/core/startos/src/bins/tunnel.rs index 5e4eb0853..0afcb0cf1 100644 --- a/core/startos/src/bins/tunnel.rs +++ b/core/startos/src/bins/tunnel.rs @@ -1,4 +1,6 @@ +use std::collections::BTreeMap; use std::ffi::OsString; +use std::net::SocketAddr; use std::time::Duration; use clap::Parser; @@ -7,44 +9,77 @@ use helpers::NonDetachingJoinHandle; use rpc_toolkit::CliApp; use tokio::signal::unix::signal; use tracing::instrument; +use visit_rs::Visit; use crate::context::config::ClientConfig; use crate::context::CliContext; -use crate::net::web_server::{Acceptor, WebServer}; +use crate::net::gateway::{Bind, BindTcp}; +use crate::net::tls::{ChainedHandler, TlsListener}; +use crate::net::web_server::{Accept, Acceptor, MetadataVisitor, WebServer}; use crate::prelude::*; use crate::tunnel::context::{TunnelConfig, TunnelContext}; +use crate::tunnel::tunnel_router; use crate::util::logger::LOGGER; +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +enum WebserverListener { + Http, + Https(SocketAddr), +} +impl Visit for WebserverListener { + fn visit(&self, visitor: &mut V) -> ::Result { + visitor.visit(self) + } +} + #[instrument(skip_all)] async fn inner_main(config: &TunnelConfig) -> Result<(), Error> { let server = async { let ctx = TunnelContext::init(config).await?; let listen = ctx.listen; - let mut server = WebServer::new(Acceptor::bind([listen]).await?); + let server = WebServer::new( + Acceptor::bind_map_dyn([(WebserverListener::Http, listen)]).await?, + tunnel_router(ctx.clone()), + ); + let acceptor_setter = server.acceptor_setter(); + let https_db = ctx.db.clone(); let https_thread: NonDetachingJoinHandle<()> = tokio::spawn(async move { - let mut sub = setter_db.subscribe("/webserver".parse().unwrap()).await; + let mut sub = https_db.subscribe("/webserver".parse().unwrap()).await; while sub.recv().await.is_some() { while let Err(e) = async { - let external = setter_db.peek().await.into_webserver().de()?; - - let mut bind_err = None; - - setter.send_modify(|a| { - a.retain(|a, _| *a == listen || Some(*a) == external); - if let Some(external) = external { - if !a.contains_key(&external) { - match mio::net::TcpListener::bind(external) { + if let Some(addr) = https_db.peek().await.as_webserver().de()? { + acceptor_setter.send_if_modified(|a| { + let key = WebserverListener::Https(addr); + if !a.contains_key(&key) { + match (|| { + Ok::<_, Error>(TlsListener::new( + BindTcp.bind(addr)?, + BasicCertHandler(https_db.clone()), + )) + })() { Ok(l) => { - a.insert(external, TcpListener::from_std(l.into())); - } - Err(e) => bind_err = Some(e), - } - } - } - }); + a.retain(|k, _| *k == WebserverListener::Http); + a.insert(key, l.into_dyn()); - if let Some(e) = bind_err { - return Err(e); + true + } + Err(e) => { + tracing::error!("error adding ssl listener: {e}"); + tracing::debug!("{e:?}"); + + false + } + } + } else { + false + } + }); + } else { + acceptor_setter.send_if_modified(|a| { + let before = a.len(); + a.retain(|k, _| *k == WebserverListener::Http); + a.len() != before + }); } Ok::<_, Error>(()) @@ -58,7 +93,6 @@ async fn inner_main(config: &TunnelConfig) -> Result<(), Error> { } }) .into(); - server.serve_tunnel(ctx.clone()); let mut shutdown_recv = ctx.shutdown.subscribe(); @@ -97,7 +131,7 @@ async fn inner_main(config: &TunnelConfig) -> Result<(), Error> { .with_kind(crate::ErrorKind::Unknown)?; sig_handler.wait_for_abort().await; - setter_thread.wait_for_abort().await; + https_thread.wait_for_abort().await; Ok::<_, Error>(server) } diff --git a/core/startos/src/context/cli.rs b/core/startos/src/context/cli.rs index c3a2d9151..b10af854a 100644 --- a/core/startos/src/context/cli.rs +++ b/core/startos/src/context/cli.rs @@ -172,7 +172,7 @@ impl CliContext { return Ok(secret.into()) } Err(Error::new( - eyre!("Developer Key does not exist! Please run `start-cli init` before running this command."), + eyre!("Developer Key does not exist! Please run `start-cli init-key` before running this command."), crate::ErrorKind::Uninitialized )) }) diff --git a/core/startos/src/context/rpc.rs b/core/startos/src/context/rpc.rs index 585523028..cb4114715 100644 --- a/core/startos/src/context/rpc.rs +++ b/core/startos/src/context/rpc.rs @@ -31,10 +31,11 @@ use crate::disk::OsPartitionInfo; use crate::init::{check_time_is_synchronized, InitResult}; use crate::install::PKG_ARCHIVE_DIR; use crate::lxc::LxcManager; +use crate::net::gateway::UpgradableListener; use crate::net::net_controller::{NetController, NetService}; use crate::net::socks::DEFAULT_SOCKS_LISTEN; use crate::net::utils::{find_eth_iface, find_wifi_iface}; -use crate::net::web_server::{UpgradableListener, WebServerAcceptorSetter}; +use crate::net::web_server::WebServerAcceptorSetter; use crate::net::wifi::WpaCli; use crate::prelude::*; use crate::progress::{FullProgressTracker, PhaseProgressTrackerHandle}; @@ -46,7 +47,7 @@ use crate::shutdown::Shutdown; use crate::util::io::delete_file; use crate::util::lshw::LshwDevice; use crate::util::sync::{SyncMutex, SyncRwLock, Watch}; -use crate::{DATA_DIR, HOST_IP}; +use crate::DATA_DIR; pub struct RpcContextSeed { is_closed: AtomicBool, diff --git a/core/startos/src/context/setup.rs b/core/startos/src/context/setup.rs index 94c81a9db..5594f781a 100644 --- a/core/startos/src/context/setup.rs +++ b/core/startos/src/context/setup.rs @@ -20,7 +20,8 @@ use crate::context::config::ServerConfig; use crate::context::RpcContext; use crate::disk::OsPartitionInfo; use crate::hostname::Hostname; -use crate::net::web_server::{UpgradableListener, WebServer, WebServerAcceptorSetter}; +use crate::net::gateway::UpgradableListener; +use crate::net::web_server::{WebServer, WebServerAcceptorSetter}; use crate::prelude::*; use crate::progress::FullProgressTracker; use crate::rpc_continuations::{Guid, RpcContinuation, RpcContinuations}; diff --git a/core/startos/src/db/mod.rs b/core/startos/src/db/mod.rs index c0d2415fb..b271deb8d 100644 --- a/core/startos/src/db/mod.rs +++ b/core/startos/src/db/mod.rs @@ -31,13 +31,23 @@ lazy_static::lazy_static! { } pub trait DbAccess: Sized { - type Key<'a>; - fn access<'a>(db: &'a Model, key: Self::Key<'_>) -> &'a Model; + fn access<'a>(db: &'a Model) -> &'a Model; } -pub trait DbAccessMut: Sized { +pub trait DbAccessMut: DbAccess { + fn access_mut<'a>(db: &'a mut Model) -> &'a mut Model; +} + +pub trait DbAccessByKey: Sized { type Key<'a>; - fn access_mut<'a>(db: &'a mut Model, key: Self::Key<'_>) -> &'a mut Model; + fn access_by_key<'a>(db: &'a Model, key: Self::Key<'_>) -> Option<&'a Model>; +} + +pub trait DbAccessMutByKey: DbAccessByKey { + fn access_mut_by_key<'a>( + db: &'a mut Model, + key: Self::Key<'_>, + ) -> Option<&'a mut Model>; } pub fn db() -> ParentHandler { diff --git a/core/startos/src/db/model/public.rs b/core/startos/src/db/model/public.rs index 780541211..caebe5c85 100644 --- a/core/startos/src/db/model/public.rs +++ b/core/startos/src/db/model/public.rs @@ -17,6 +17,8 @@ use ts_rs::TS; use crate::account::AccountInfo; use crate::db::model::package::AllPackageData; +use crate::db::model::Database; +use crate::db::DbAccessByKey; use crate::net::acme::AcmeProvider; use crate::net::host::binding::{AddSslOptions, BindInfo, BindOptions, NetInfo}; use crate::net::host::Host; @@ -295,6 +297,19 @@ pub enum NetworkInterfaceType { pub struct AcmeSettings { pub contact: Vec, } +impl DbAccessByKey for Database { + type Key<'a> = &'a AcmeProvider; + fn access_by_key<'a>( + db: &'a Model, + key: Self::Key<'_>, + ) -> Option<&'a Model> { + db.as_public() + .as_server_info() + .as_network() + .as_acme() + .as_idx(key) + } +} #[derive(Debug, Deserialize, Serialize, HasModel, TS)] #[serde(rename_all = "camelCase")] diff --git a/core/startos/src/db/prelude.rs b/core/startos/src/db/prelude.rs index 6004c8776..f45d051cc 100644 --- a/core/startos/src/db/prelude.rs +++ b/core/startos/src/db/prelude.rs @@ -1,6 +1,7 @@ use std::collections::{BTreeMap, BTreeSet}; use std::marker::PhantomData; use std::str::FromStr; +use std::sync::Arc; use chrono::{DateTime, Utc}; use imbl::OrdMap; @@ -167,6 +168,21 @@ impl Model> { } } +impl Model> { + pub fn deref(self) -> Model { + use patch_db::ModelExt; + self.transmute(|a| a) + } + pub fn as_deref(&self) -> &Model { + use patch_db::ModelExt; + self.transmute_ref(|a| a) + } + pub fn as_deref_mut(&mut self) -> &mut Model { + use patch_db::ModelExt; + self.transmute_mut(|a| a) + } +} + pub trait Map: DeserializeOwned + Serialize { type Key; type Value; @@ -196,7 +212,7 @@ where type Key = JsonKey; type Value = B; fn key_str(key: &Self::Key) -> Result, Error> { - serde_json::to_string(key).with_kind(ErrorKind::Serialization) + serde_json::to_string(&key.0).with_kind(ErrorKind::Serialization) } } diff --git a/core/startos/src/lib.rs b/core/startos/src/lib.rs index 25b4733e0..c07f58fce 100644 --- a/core/startos/src/lib.rs +++ b/core/startos/src/lib.rs @@ -226,7 +226,7 @@ pub fn main_api() -> ParentHandler { util::rpc::util::().with_about("Command for calculating the blake3 hash of a file"), ) .subcommand( - "init", + "init-key", from_fn_async(developer::init) .no_display() .with_about("Create developer key if it doesn't exist"), @@ -241,6 +241,7 @@ pub fn main_api() -> ParentHandler { diagnostic::diagnostic::() .with_about("Commands to display logs, restart the server, etc"), ) + .subcommand("init", init::init_api::()) .subcommand("setup", setup::setup::()) .subcommand( "install", diff --git a/core/startos/src/net/acme.rs b/core/startos/src/net/acme.rs index fc1698043..22514fcb7 100644 --- a/core/startos/src/net/acme.rs +++ b/core/startos/src/net/acme.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use async_acme::acme::{Identifier, ACME_TLS_ALPN_NAME}; use clap::builder::ValueParserFactory; use clap::Parser; -use futures::StreamExt; +use futures::{FutureExt, StreamExt}; use imbl_value::InternedString; use itertools::Itertools; use models::{ErrorData, FromStrParser}; @@ -16,6 +16,7 @@ use rpc_toolkit::{from_fn_async, Context, HandlerExt, ParentHandler}; use serde::{Deserialize, Serialize}; use tokio_rustls::rustls::crypto::CryptoProvider; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; +use tokio_rustls::rustls::server::ClientHello; use tokio_rustls::rustls::sign::CertifiedKey; use tokio_rustls::rustls::ServerConfig; use ts_rs::TS; @@ -24,7 +25,7 @@ use url::Url; use crate::context::{CliContext, RpcContext}; use crate::db::model::public::AcmeSettings; use crate::db::model::Database; -use crate::db::{DbAccess, DbAccessMut}; +use crate::db::{DbAccess, DbAccessByKey, DbAccessMut}; use crate::net::tls::{SingleCertResolver, TlsHandler}; use crate::net::web_server::Accept; use crate::prelude::*; @@ -34,28 +35,28 @@ use crate::util::sync::{SyncMutex, Watch}; pub type AcmeTlsAlpnCache = Arc>>>>>; -pub struct AcmeTlsHandler<'a, M: HasModel, S: 'a> { - pub db: &'a TypedPatchDb, - pub acme_cache: &'a AcmeTlsAlpnCache, - pub crypto_provider: &'a Arc, +pub struct AcmeTlsHandler { + pub db: TypedPatchDb, + pub acme_cache: AcmeTlsAlpnCache, + pub crypto_provider: Arc, pub get_provider: S, pub in_progress: Watch>>, } -impl<'b, M, S> AcmeTlsHandler<'b, M, S> +impl AcmeTlsHandler where - for<'a> M: DbAccess = ()> - + DbAccess = &'a AcmeProvider> - + DbAccessMut = ()> + for<'a> M: DbAccessByKey = &'a AcmeProvider> + + DbAccessMut + HasModel> + Send + Sync, - S: GetAcmeProvider<'b> + Clone + 'b, + S: GetAcmeProvider + Clone, { pub async fn get_cert(&self, san_info: &BTreeSet) -> Option { - let provider = self.get_provider.clone().get_provider(san_info).await?; + let provider = self.get_provider.get_provider(san_info).await?; + let provider = provider.as_ref(); loop { let peek = self.db.peek().await; - let store = >::access(&peek, ()); + let store = >::access(&peek); if let Some(cert) = store .as_certs() .as_idx(&provider.0) @@ -93,7 +94,7 @@ where continue; } - let contact = >::access(&peek, provider) + let contact = >::access_by_key(&peek, &provider)? .as_contact() .de() .log_err()?; @@ -141,37 +142,29 @@ where } } -pub trait GetAcmeProvider<'a> { - fn get_provider<'b>( - self, - san_info: &'b BTreeSet, - ) -> impl Future> + Send + 'b - where - Self: 'b; +pub trait GetAcmeProvider { + fn get_provider<'a, 'b: 'a>( + &'b self, + san_info: &'a BTreeSet, + ) -> impl Future + Send + 'b>> + Send + 'a; } -impl<'b, A, M, S> TlsHandler for &'b AcmeTlsHandler<'b, M, S> +impl<'a, A, M, S> TlsHandler<'a, A> for Arc> where - A: Accept, + A: Accept + 'a, ::Metadata: Send + Sync, - for<'a> M: DbAccess = ()> - + DbAccess = &'a AcmeProvider> - + DbAccessMut = ()> + for<'m> M: DbAccessByKey = &'m AcmeProvider> + + DbAccessMut + HasModel> + Send + Sync, - S: GetAcmeProvider<'b> + Clone + Send + Sync + 'b, + S: GetAcmeProvider + Clone + Send + Sync, { - async fn get_config<'a>( - self, - hello: &'a tokio_rustls::rustls::server::ClientHello<'a>, - metadata: &'a ::Metadata, - ) -> Option - where - Self: 'a, - A: 'a, - ::Metadata: 'a, - { + async fn get_config( + &'a mut self, + hello: &'a ClientHello<'a>, + _: &'a ::Metadata, + ) -> Option { let domain = hello.server_name()?; if hello .alpn() @@ -238,14 +231,12 @@ impl AcmeCertStore { } impl DbAccess for Database { - type Key<'a> = (); - fn access<'a>(db: &'a Model, _: Self::Key<'_>) -> &'a Model { + fn access<'a>(db: &'a Model) -> &'a Model { db.as_private().as_key_store().as_acme() } } impl DbAccessMut for Database { - type Key<'a> = (); - fn access_mut<'a>(db: &'a mut Model, _: Self::Key<'_>) -> &'a mut Model { + fn access_mut<'a>(db: &'a mut Model) -> &'a mut Model { db.as_private_mut().as_key_store_mut().as_acme_mut() } } @@ -260,18 +251,14 @@ pub struct AcmeCertCache<'a, M: HasModel>(pub &'a TypedPatchDb); #[async_trait::async_trait] impl<'a, M> async_acme::cache::AcmeCache for AcmeCertCache<'a, M> where - for<'b> M: HasModel> - + DbAccess = ()> - + DbAccessMut = ()> - + Send - + Sync, + M: HasModel> + DbAccessMut + Send + Sync, { type Error = ErrorData; async fn read_account(&self, contacts: &[&str]) -> Result>, Self::Error> { let contacts = JsonKey::new(contacts.into_iter().map(|s| (*s).to_owned()).collect_vec()); let peek = self.0.peek().await; - let Some(account) = M::access(&peek, ()).as_accounts().as_idx(&contacts) else { + let Some(account) = M::access(&peek).as_accounts().as_idx(&contacts) else { return Ok(None); }; Ok(Some(account.de()?.0.document.into_vec())) @@ -285,7 +272,7 @@ where }; self.0 .mutate(|db| { - M::access_mut(db, ()) + M::access_mut(db) .as_accounts_mut() .insert(&contacts, &Pem::new(key)) }) @@ -312,7 +299,7 @@ where .parse::() .with_kind(ErrorKind::ParseUrl)?; let peek = self.0.peek().await; - let Some(cert) = M::access(&peek, ()) + let Some(cert) = M::access(&peek) .as_certs() .as_idx(&directory_url) .and_then(|a| a.as_idx(&identifiers)) @@ -370,7 +357,7 @@ where }; self.0 .mutate(|db| { - M::access_mut(db, ()) + M::access_mut(db) .as_certs_mut() .upsert(&directory_url, || Ok(BTreeMap::new()))? .insert(&identifiers, &cert) @@ -443,6 +430,11 @@ impl AsRef for AcmeProvider { self.0.as_str() } } +impl AsRef for AcmeProvider { + fn as_ref(&self) -> &AcmeProvider { + self + } +} impl ValueParserFactory for AcmeProvider { type Parser = FromStrParser; fn value_parser() -> Self::Parser { diff --git a/core/startos/src/net/forward.rs b/core/startos/src/net/forward.rs index 31fb29dee..972f4f75b 100644 --- a/core/startos/src/net/forward.rs +++ b/core/startos/src/net/forward.rs @@ -377,6 +377,9 @@ impl PortForwardController { } async fn forward(interface: &str, source: SocketAddr, target: SocketAddr) -> Result<(), Error> { + if interface == START9_BRIDGE_IFACE { + return Ok(()); + } if source.is_ipv6() { return Ok(()); // TODO: socat? ip6tables? } @@ -393,6 +396,9 @@ async fn forward(interface: &str, source: SocketAddr, target: SocketAddr) -> Res } async fn unforward(interface: &str, source: SocketAddr, target: SocketAddr) -> Result<(), Error> { + if interface == START9_BRIDGE_IFACE { + return Ok(()); + } if source.is_ipv6() { return Ok(()); // TODO: socat? ip6tables? } diff --git a/core/startos/src/net/gateway.rs b/core/startos/src/net/gateway.rs index 56bb151f9..884ae9c52 100644 --- a/core/startos/src/net/gateway.rs +++ b/core/startos/src/net/gateway.rs @@ -552,6 +552,7 @@ async fn watch_ip( let managed = device_proxy.managed().await?; if !managed { + dbg!("unmanaged", &iface); return Ok(()); } let dac = device_proxy.active_connection().await?; @@ -596,10 +597,6 @@ async fn watch_ip( _ => None, }; - if device_type == Some(NetworkInterfaceType::Loopback) { - return Ok(()); - } - let name = InternedString::from(active_connection_proxy.id().await?); let dhcp4_config = active_connection_proxy.dhcp4_config().await?; @@ -676,7 +673,14 @@ async fn watch_ip( .into_iter() .map(IpNet::try_from) .try_collect()?; - let wan_ip = if !subnets.is_empty() { + let wan_ip = if !subnets.is_empty() + && !matches!( + device_type, + Some( + NetworkInterfaceType::Bridge + | NetworkInterfaceType::Loopback + ) + ) { match get_wan_ipv4(iface.as_str()).await { Ok(a) => a, Err(e) => { @@ -1535,6 +1539,17 @@ pub struct NetworkInterfaceListenerAcceptMetadata { pub inner: ::Metadata, pub info: GatewayInfo, } +impl Clone for NetworkInterfaceListenerAcceptMetadata +where + ::Metadata: Clone, +{ + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + info: self.info.clone(), + } + } +} impl Visit for NetworkInterfaceListenerAcceptMetadata where B: Bind, diff --git a/core/startos/src/net/keys.rs b/core/startos/src/net/keys.rs index 2cfcb025d..41e96eedd 100644 --- a/core/startos/src/net/keys.rs +++ b/core/startos/src/net/keys.rs @@ -8,6 +8,7 @@ use crate::prelude::*; #[derive(Debug, Deserialize, Serialize, HasModel)] #[model = "Model"] +#[serde(rename_all = "camelCase")] pub struct KeyStore { pub onion: OnionStore, pub local_certs: CertStore, diff --git a/core/startos/src/net/net_controller.rs b/core/startos/src/net/net_controller.rs index bc2cf81b6..a6316bc10 100644 --- a/core/startos/src/net/net_controller.rs +++ b/core/startos/src/net/net_controller.rs @@ -9,9 +9,10 @@ use ipnet::IpNet; use models::{GatewayId, HostId, OptionExt, PackageId}; use tokio::sync::Mutex; use tokio::task::JoinHandle; +use tokio_rustls::rustls::ClientConfig as TlsClientConfig; use tracing::instrument; -use crate::db::model::public::{NetworkInterfaceInfo, NetworkInterfaceType}; +use crate::db::model::public::NetworkInterfaceType; use crate::db::model::Database; use crate::error::ErrorCollection; use crate::hostname::Hostname; @@ -28,7 +29,7 @@ use crate::net::service_interface::{GatewayInfo, HostnameInfo, IpHostname, Onion use crate::net::socks::SocksController; use crate::net::tor::{OnionAddress, TorController, TorSecretKey}; use crate::net::utils::ipv6_is_local; -use crate::net::vhost::{AlpnInfo, ProxyTarget, VHostController}; +use crate::net::vhost::{AlpnInfo, DynVHostTarget, ProxyTarget, VHostController}; use crate::prelude::*; use crate::service::effects::callbacks::ServiceCallbacks; use crate::util::serde::MaybeUtf8String; @@ -38,6 +39,7 @@ pub struct NetController { pub(crate) db: TypedPatchDb, pub(super) tor: TorController, pub(super) vhost: VHostController, + pub(super) tls_client_config: Arc, pub(crate) net_iface: Arc, pub(super) dns: DnsController, pub(super) forward: PortForwardController, @@ -55,10 +57,24 @@ impl NetController { let net_iface = Arc::new(NetworkInterfaceController::new(db.clone())); let tor = TorController::new()?; let socks = SocksController::new(socks_listen, tor.clone())?; + let crypto_provider = Arc::new(tokio_rustls::rustls::crypto::ring::default_provider()); + let tls_client_config = Arc::new(crate::net::tls::client_config( + crypto_provider.clone(), + [&*db + .peek() + .await + .as_private() + .as_key_store() + .as_local_certs() + .as_root_cert() + .de()? + .0], + )?); Ok(Self { db: db.clone(), tor, - vhost: VHostController::new(db.clone(), net_iface.clone()), + vhost: VHostController::new(db.clone(), net_iface.clone(), crypto_provider), + tls_client_config, dns: DnsController::init(db, &net_iface.watcher).await?, forward: PortForwardController::new(net_iface.watcher.subscribe()), net_iface, @@ -267,7 +283,9 @@ impl NetServiceData { filter: bind.net.clone().into_dyn(), acme: None, addr, - connect_ssl: connect_ssl.clone(), + connect_ssl: connect_ssl + .clone() + .map(|_| ctrl.tls_client_config.clone()), }, ); } @@ -288,7 +306,9 @@ impl NetServiceData { .into_dyn(), acme: None, addr, - connect_ssl: connect_ssl.clone(), + connect_ssl: connect_ssl + .clone() + .map(|_| ctrl.tls_client_config.clone()), }, ); // TODO: wrap onion ssl stream directly in tor ctrl } @@ -315,7 +335,9 @@ impl NetServiceData { .into_dyn(), acme: public.acme.clone(), addr, - connect_ssl: connect_ssl.clone(), + connect_ssl: connect_ssl + .clone() + .map(|_| ctrl.tls_client_config.clone()), }, ); vhosts.insert( @@ -340,7 +362,9 @@ impl NetServiceData { .into_dyn(), acme: public.acme.clone(), addr, - connect_ssl: connect_ssl.clone(), + connect_ssl: connect_ssl + .clone() + .map(|_| ctrl.tls_client_config.clone()), }, ); } else { @@ -354,7 +378,9 @@ impl NetServiceData { .into_dyn(), acme: None, addr, - connect_ssl: connect_ssl.clone(), + connect_ssl: connect_ssl + .clone() + .map(|_| ctrl.tls_client_config.clone()), }, ); } @@ -379,7 +405,9 @@ impl NetServiceData { .into_dyn(), acme: public.acme.clone(), addr, - connect_ssl: connect_ssl.clone(), + connect_ssl: connect_ssl + .clone() + .map(|_| ctrl.tls_client_config.clone()), }, ); } else { @@ -393,7 +421,9 @@ impl NetServiceData { .into_dyn(), acme: None, addr, - connect_ssl: connect_ssl.clone(), + connect_ssl: connect_ssl + .clone() + .map(|_| ctrl.tls_client_config.clone()), }, ); } @@ -427,6 +457,11 @@ impl NetServiceData { hostname_info.remove(port).unwrap_or_default(); for (gateway_id, info) in net_ifaces .iter() + .filter(|(_, info)| { + info.ip_info.as_ref().map_or(false, |i| { + !matches!(i.device_type, Some(NetworkInterfaceType::Bridge)) + }) + }) .filter(|(id, info)| bind.net.filter(id, info)) { let gateway = GatewayInfo { @@ -651,7 +686,10 @@ impl NetServiceData { if let Some(prev) = prev { prev } else { - (target.clone(), ctrl.vhost.add(key.0, key.1, target)?) + ( + target.clone(), + ctrl.vhost.add(key.0, key.1, DynVHostTarget::new(target))?, + ) }, ); } else { diff --git a/core/startos/src/net/ssl.rs b/core/startos/src/net/ssl.rs index afa1360f4..e5dcb74af 100644 --- a/core/startos/src/net/ssl.rs +++ b/core/startos/src/net/ssl.rs @@ -1,7 +1,8 @@ -use std::cmp::{Ordering, min}; +use std::cmp::{min, Ordering}; use std::collections::{BTreeMap, BTreeSet}; use std::net::IpAddr; use std::path::Path; +use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use futures::FutureExt; @@ -14,18 +15,32 @@ use openssl::error::ErrorStack; use openssl::hash::MessageDigest; use openssl::nid::Nid; use openssl::pkey::{PKey, Private}; -use openssl::x509::{X509, X509Builder, X509Extension, X509NameBuilder}; +use openssl::x509::extension::{ + AuthorityKeyIdentifier, BasicConstraints, KeyUsage, SubjectAlternativeName, + SubjectKeyIdentifier, +}; +use openssl::x509::{X509Builder, X509NameBuilder, X509}; use openssl::*; use patch_db::HasModel; use serde::{Deserialize, Serialize}; +use tokio_rustls::rustls::crypto::CryptoProvider; +use tokio_rustls::rustls::pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer}; +use tokio_rustls::rustls::server::ClientHello; +use tokio_rustls::rustls::ServerConfig; use tracing::instrument; +use visit_rs::Visit; -use crate::SOURCE_DATE; use crate::account::AccountInfo; +use crate::db::model::Database; +use crate::db::{DbAccess, DbAccessMut}; use crate::hostname::Hostname; use crate::init::check_time_is_synchronized; +use crate::net::gateway::GatewayInfo; +use crate::net::tls::TlsHandler; +use crate::net::web_server::{extract, Accept, ExtractVisitor, TcpMetadata}; use crate::prelude::*; use crate::util::serde::Pem; +use crate::SOURCE_DATE; pub fn gen_nistp256() -> Result, ErrorStack> { PKey::from_ec_key(EcKey::generate(&*EcGroup::from_curve_name( @@ -130,6 +145,16 @@ impl Model { }) } } +impl DbAccess for Database { + fn access<'a>(db: &'a Model) -> &'a Model { + db.as_private().as_key_store().as_local_certs() + } +} +impl DbAccessMut for Database { + fn access_mut<'a>(db: &'a mut Model) -> &'a mut Model { + db.as_private_mut().as_key_store_mut().as_local_certs_mut() + } +} #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] pub struct CertData { @@ -305,22 +330,16 @@ pub fn make_root_cert( let cfg = conf::Conf::new(conf::ConfMethod::default())?; let ctx = builder.x509v3_context(None, Some(&cfg)); // subjectKeyIdentifier = hash - let subject_key_identifier = - X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_KEY_IDENTIFIER, "hash")?; + let subject_key_identifier = SubjectKeyIdentifier::new().build(&ctx)?; // basicConstraints = critical, CA:true, pathlen:0 - let basic_constraints = X509Extension::new_nid( - Some(&cfg), - Some(&ctx), - Nid::BASIC_CONSTRAINTS, - "critical,CA:true", - )?; + let basic_constraints = BasicConstraints::new().critical().ca().build()?; // keyUsage = critical, digitalSignature, cRLSign, keyCertSign - let key_usage = X509Extension::new_nid( - Some(&cfg), - Some(&ctx), - Nid::KEY_USAGE, - "critical,digitalSignature,cRLSign,keyCertSign", - )?; + let key_usage = KeyUsage::new() + .critical() + .digital_signature() + .crl_sign() + .key_cert_sign() + .build()?; builder.append_extension(subject_key_identifier)?; builder.append_extension(basic_constraints)?; builder.append_extension(key_usage)?; @@ -355,30 +374,23 @@ pub fn make_int_cert( let cfg = conf::Conf::new(conf::ConfMethod::default())?; let ctx = builder.x509v3_context(Some(&signer.1), Some(&cfg)); + // subjectKeyIdentifier = hash - let subject_key_identifier = - X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_KEY_IDENTIFIER, "hash")?; + let subject_key_identifier = SubjectKeyIdentifier::new().build(&ctx)?; // authorityKeyIdentifier = keyid:always,issuer - let authority_key_identifier = X509Extension::new_nid( - Some(&cfg), - Some(&ctx), - Nid::AUTHORITY_KEY_IDENTIFIER, - "keyid:always,issuer", - )?; + let authority_key_identifier = AuthorityKeyIdentifier::new() + .keyid(false) + .issuer(true) + .build(&ctx)?; // basicConstraints = critical, CA:true, pathlen:0 - let basic_constraints = X509Extension::new_nid( - Some(&cfg), - Some(&ctx), - Nid::BASIC_CONSTRAINTS, - "critical,CA:true,pathlen:0", - )?; + let basic_constraints = BasicConstraints::new().critical().ca().pathlen(0).build()?; // keyUsage = critical, digitalSignature, cRLSign, keyCertSign - let key_usage = X509Extension::new_nid( - Some(&cfg), - Some(&ctx), - Nid::KEY_USAGE, - "critical,digitalSignature,cRLSign,keyCertSign", - )?; + let key_usage = KeyUsage::new() + .critical() + .digital_signature() + .crl_sign() + .key_cert_sign() + .build()?; builder.append_extension(subject_key_identifier)?; builder.append_extension(authority_key_identifier)?; builder.append_extension(basic_constraints)?; @@ -428,6 +440,16 @@ impl SANInfo { } Self { dns, ips } } + pub fn x509_extension(&self) -> SubjectAlternativeName { + let mut san = SubjectAlternativeName::new(); + for h in &self.dns { + san.dns(&h.as_str()); + } + for ip in &self.ips { + san.ip(&ip.to_string()); + } + san + } } impl std::fmt::Display for SANInfo { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -490,28 +512,20 @@ pub fn make_leaf_cert( // Extensions let cfg = conf::Conf::new(conf::ConfMethod::default())?; let ctx = builder.x509v3_context(Some(&signer.1), Some(&cfg)); - // subjectKeyIdentifier = hash - let subject_key_identifier = - X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_KEY_IDENTIFIER, "hash")?; - // authorityKeyIdentifier = keyid:always,issuer - let authority_key_identifier = X509Extension::new_nid( - Some(&cfg), - Some(&ctx), - Nid::AUTHORITY_KEY_IDENTIFIER, - "keyid,issuer:always", - )?; - let basic_constraints = - X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::BASIC_CONSTRAINTS, "CA:FALSE")?; - let key_usage = X509Extension::new_nid( - Some(&cfg), - Some(&ctx), - Nid::KEY_USAGE, - "critical,digitalSignature,keyEncipherment", - )?; - let san_string = applicant.1.to_string(); - let subject_alt_name = - X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_ALT_NAME, &san_string)?; + let subject_key_identifier = SubjectKeyIdentifier::new().build(&ctx)?; + let authority_key_identifier = AuthorityKeyIdentifier::new() + .keyid(false) + .issuer(true) + .build(&ctx)?; + let subject_alt_name = applicant.1.x509_extension().build(&ctx)?; + let basic_constraints = BasicConstraints::new().build()?; + let key_usage = KeyUsage::new() + .critical() + .digital_signature() + .key_encipherment() + .build()?; + builder.append_extension(subject_key_identifier)?; builder.append_extension(authority_key_identifier)?; builder.append_extension(subject_alt_name)?; @@ -561,28 +575,20 @@ pub fn make_self_signed(applicant: (&PKey, &SANInfo)) -> Result, &SANInfo)) -> Result { + pub db: TypedPatchDb, + pub crypto_provider: Arc, +} +impl Clone for RootCaTlsHandler { + fn clone(&self) -> Self { + Self { + db: self.db.clone(), + crypto_provider: self.crypto_provider.clone(), + } + } +} + +impl<'a, A, M> TlsHandler<'a, A> for RootCaTlsHandler +where + A: Accept + 'a, + ::Metadata: Visit> + + Visit> + + Clone + + Send + + Sync + + 'static, + M: HasModel> + DbAccessMut + Send + Sync, +{ + async fn get_config( + &mut self, + hello: &ClientHello<'_>, + metadata: &::Metadata, + ) -> Option { + let hostnames: BTreeSet = hello + .server_name() + .map(InternedString::from) + .into_iter() + .chain( + extract::(metadata) + .map(|m| m.local_addr.ip()) + .as_ref() + .map(InternedString::from_display), + ) + .chain( + extract::(metadata) + .and_then(|i| i.info.ip_info) + .and_then(|i| i.wan_ip) + .as_ref() + .map(InternedString::from_display), + ) + .collect(); + let cert = self + .db + .mutate(|db| M::access_mut(db).cert_for(&hostnames)) + .await + .result + .log_err()?; + let cfg = ServerConfig::builder_with_provider(self.crypto_provider.clone()) + .with_safe_default_protocol_versions() + .log_err()? + .with_no_client_auth(); + if hello + .signature_schemes() + .contains(&tokio_rustls::rustls::SignatureScheme::ED25519) + { + cfg.with_single_cert( + cert.fullchain_ed25519() + .into_iter() + .map(|c| { + Ok(tokio_rustls::rustls::pki_types::CertificateDer::from( + c.to_der()?, + )) + }) + .collect::>() + .log_err()?, + PrivateKeyDer::from(PrivatePkcs8KeyDer::from( + cert.leaf.keys.ed25519.private_key_to_pkcs8().log_err()?, + )), + ) + } else { + cfg.with_single_cert( + cert.fullchain_nistp256() + .into_iter() + .map(|c| { + Ok(tokio_rustls::rustls::pki_types::CertificateDer::from( + c.to_der()?, + )) + }) + .collect::>() + .log_err()?, + PrivateKeyDer::from(PrivatePkcs8KeyDer::from( + cert.leaf.keys.nistp256.private_key_to_pkcs8().log_err()?, + )), + ) + } + .log_err() + } +} diff --git a/core/startos/src/net/tls.rs b/core/startos/src/net/tls.rs index 45284e400..c6355fade 100644 --- a/core/startos/src/net/tls.rs +++ b/core/startos/src/net/tls.rs @@ -4,10 +4,13 @@ use std::task::Poll; use futures::future::BoxFuture; use futures::FutureExt; use imbl_value::InternedString; +use openssl::x509::X509Ref; use tokio::io::AsyncWriteExt; +use tokio_rustls::rustls::crypto::CryptoProvider; +use tokio_rustls::rustls::pki_types::CertificateDer; use tokio_rustls::rustls::server::{Acceptor, ClientHello, ResolvesServerCert}; use tokio_rustls::rustls::sign::CertifiedKey; -use tokio_rustls::rustls::ServerConfig; +use tokio_rustls::rustls::{ClientConfig, RootCertStore, ServerConfig}; use tokio_rustls::LazyConfigAcceptor; use visit_rs::{Visit, VisitFields}; @@ -38,35 +41,28 @@ impl Visit for TlsHandshakeInfo { } } -pub trait TlsHandler { - fn get_config<'a>( - self, +pub trait TlsHandler<'a, A: Accept> { + fn get_config( + &'a mut self, hello: &'a ClientHello<'a>, metadata: &'a A::Metadata, - ) -> impl Future> + Send + 'a - where - Self: 'a, - A: 'a, - A::Metadata: 'a; + ) -> impl Future> + Send + 'a; } #[derive(Clone)] pub struct ChainedHandler(pub H0, pub H1); -impl TlsHandler for ChainedHandler +impl<'a, A, H0, H1> TlsHandler<'a, A> for ChainedHandler where - A: Accept, + A: Accept + 'a, ::Metadata: Send + Sync, - H0: TlsHandler + Send, - H1: TlsHandler + Send, + H0: TlsHandler<'a, A> + Send, + H1: TlsHandler<'a, A> + Send, { - async fn get_config<'a>( - self, + async fn get_config( + &'a mut self, hello: &'a ClientHello<'a>, metadata: &'a ::Metadata, - ) -> Option - where - Self: 'a, - { + ) -> Option { if let Some(config) = self.0.get_config(hello, metadata).await { return Some(config); } @@ -74,6 +70,7 @@ where } } +#[derive(Clone)] pub struct TlsHandlerWrapper { pub inner: I, pub wrapper: W, @@ -81,7 +78,7 @@ pub struct TlsHandlerWrapper { pub trait WrapTlsHandler { fn wrap<'a>( - self, + &'a mut self, prev: ServerConfig, hello: &'a ClientHello<'a>, metadata: &'a ::Metadata, @@ -90,23 +87,18 @@ pub trait WrapTlsHandler { Self: 'a; } -impl TlsHandler for TlsHandlerWrapper +impl<'a, A, I, W> TlsHandler<'a, A> for TlsHandlerWrapper where - A: Accept, + A: Accept + 'a, ::Metadata: Send + Sync, - I: TlsHandler + Send, + I: TlsHandler<'a, A> + Send, W: WrapTlsHandler + Send, { - async fn get_config<'a>( - self, + async fn get_config( + &'a mut self, hello: &'a ClientHello<'a>, metadata: &'a ::Metadata, - ) -> Option - where - Self: 'a, - A: 'a, - ::Metadata: 'a, - { + ) -> Option { let prev = self.inner.get_config(hello, metadata).await?; self.wrapper.wrap(prev, hello, metadata).await } @@ -120,13 +112,20 @@ impl ResolvesServerCert for SingleCertResolver { } } -pub struct TlsListener> { +pub struct TlsListener TlsHandler<'a, A>> { pub accept: A, pub tls_handler: H, - in_progress: - Vec, AcceptStream)>, Error>>>, + in_progress: Vec< + BoxFuture< + 'static, + ( + H, + Result, AcceptStream)>, Error>, + ), + >, + >, } -impl> TlsListener { +impl TlsHandler<'a, A>> TlsListener { pub fn new(accept: A, cert_handler: H) -> Self { Self { accept, @@ -137,93 +136,111 @@ impl> TlsListener { } impl Accept for TlsListener where - A: Accept, + A: Accept + 'static, A::Metadata: Send + 'static, - H: TlsHandler + Clone + Send + 'static, + for<'a> H: TlsHandler<'a, A> + Clone + Send + 'static, { type Metadata = TlsMetadata; fn poll_accept( &mut self, cx: &mut std::task::Context<'_>, ) -> Poll> { - if let Some((idx, res)) = self - .in_progress - .iter_mut() - .enumerate() - .find_map(|(idx, fut)| match fut.poll_unpin(cx) { - Poll::Ready(a) => Some((idx, a)), - Poll::Pending => None, - }) - { - self.in_progress.swap_remove(idx); - if let Some(res) = res.transpose() { - return Poll::Ready(res); - } - } - - if let Poll::Ready((metadata, stream)) = self.accept.poll_accept(cx)? { - let tls_handler = self.tls_handler.clone(); - self.in_progress.push( - async move { - let mut acceptor = - LazyConfigAcceptor::new(Acceptor::default(), BackTrackingIO::new(stream)); - let mut mid: tokio_rustls::StartHandshake> = - match (&mut acceptor).await { - Ok(a) => a, - Err(e) => { - let mut stream = acceptor.take_io().or_not_found("acceptor io")?; - let (_, buf) = stream.rewind(); - if std::str::from_utf8(buf) - .ok() - .and_then(|buf| { - buf.lines() - .map(|l| l.trim()) - .filter(|l| !l.is_empty()) - .next() - }) - .map_or(false, |buf| { - regex::Regex::new("[A-Z]+ (.+) HTTP/1") - .unwrap() - .is_match(buf) - }) - { - handle_http_on_https(stream).await.log_err(); - - return Ok(None); - } else { - return Err(e).with_kind(ErrorKind::Network); - } - } - }; - let hello = mid.client_hello(); - if let Some(cfg) = tls_handler.get_config(&hello, &metadata).await { - let metadata = TlsMetadata { - inner: metadata, - tls_info: TlsHandshakeInfo { - sni: hello.server_name().map(InternedString::intern), - alpn: hello - .alpn() - .into_iter() - .flatten() - .map(|a| MaybeUtf8String(a.to_vec())) - .collect(), - }, - }; - let buffered = mid.io.stop_buffering(); - mid.io - .write_all(&buffered) - .await - .with_kind(ErrorKind::Network)?; - return Ok(Some(( - metadata, - Box::pin(mid.into_stream(Arc::new(cfg)).await?) as AcceptStream, - ))); - } - - Ok(None) + loop { + if let Some((idx, (handler, res))) = + self.in_progress + .iter_mut() + .enumerate() + .find_map(|(idx, fut)| match fut.poll_unpin(cx) { + Poll::Ready(a) => Some((idx, a)), + Poll::Pending => None, + }) + { + drop(self.in_progress.swap_remove(idx)); + if let Some(res) = res.transpose() { + self.tls_handler = handler; + return Poll::Ready(res); } - .boxed(), - ); + continue; + } + + if let Poll::Ready((metadata, stream)) = self.accept.poll_accept(cx)? { + crate::dbg!("ACCEPTED"); + let mut tls_handler = self.tls_handler.clone(); + self.in_progress.push( + async move { + let res = async { + let mut acceptor = LazyConfigAcceptor::new( + Acceptor::default(), + BackTrackingIO::new(stream), + ); + let mut mid: tokio_rustls::StartHandshake< + BackTrackingIO, + > = match (&mut acceptor).await { + Ok(a) => a, + Err(e) => { + let mut stream = + acceptor.take_io().or_not_found("acceptor io")?; + let (_, buf) = stream.rewind(); + if std::str::from_utf8(buf) + .ok() + .and_then(|buf| { + buf.lines() + .map(|l| l.trim()) + .filter(|l| !l.is_empty()) + .next() + }) + .map_or(false, |buf| { + regex::Regex::new("[A-Z]+ (.+) HTTP/1") + .unwrap() + .is_match(buf) + }) + { + handle_http_on_https(stream).await.log_err(); + + return Ok(None); + } else { + return Err(e).with_kind(ErrorKind::Network); + } + } + }; + let hello = mid.client_hello(); + crate::dbg!("getting config"); + if let Some(cfg) = tls_handler.get_config(&hello, &metadata).await { + crate::dbg!("config gotten"); + let metadata = TlsMetadata { + inner: metadata, + tls_info: TlsHandshakeInfo { + sni: hello.server_name().map(InternedString::intern), + alpn: hello + .alpn() + .into_iter() + .flatten() + .map(|a| MaybeUtf8String(a.to_vec())) + .collect(), + }, + }; + let buffered = mid.io.stop_buffering(); + mid.io + .write_all(&buffered) + .await + .with_kind(ErrorKind::Network)?; + return Ok(Some(( + metadata, + Box::pin(mid.into_stream(Arc::new(cfg)).await?) as AcceptStream, + ))); + } + crate::dbg!("no config"); + + Ok(None) + } + .await; + (tls_handler, res) + } + .boxed(), + ); + continue; + } + break; } Poll::Pending @@ -280,3 +297,20 @@ async fn handle_http_on_https(stream: impl ReadWriter + Unpin + 'static) -> Resu .await .map_err(|e| Error::new(color_eyre::eyre::Report::msg(e), ErrorKind::Network)) } + +pub fn client_config<'a, I: IntoIterator>( + crypto_provider: Arc, + root_certs: I, +) -> Result { + let mut certs = RootCertStore::empty(); + for cert in root_certs { + certs + .add(CertificateDer::from_slice(&cert.to_der()?)) + .with_kind(ErrorKind::OpenSsl)?; + } + Ok(ClientConfig::builder_with_provider(crypto_provider.clone()) + .with_safe_default_protocol_versions() + .with_kind(ErrorKind::OpenSsl)? + .with_root_certificates(certs) + .with_no_client_auth()) +} diff --git a/core/startos/src/net/tunnel.rs b/core/startos/src/net/tunnel.rs index 7702f7544..a31ca5a8b 100644 --- a/core/startos/src/net/tunnel.rs +++ b/core/startos/src/net/tunnel.rs @@ -125,7 +125,7 @@ pub async fn remove_tunnel( return Ok(()); }; - if existing.as_device_type().de()? != Some(NetworkInterfaceType::Wireguard) { + if existing.as_deref().as_device_type().de()? != Some(NetworkInterfaceType::Wireguard) { return Err(Error::new( eyre!("network interface {id} is not a proxy"), ErrorKind::InvalidRequest, diff --git a/core/startos/src/net/utils.rs b/core/startos/src/net/utils.rs index da4295eea..18c2c3705 100644 --- a/core/startos/src/net/utils.rs +++ b/core/startos/src/net/utils.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeMap; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV6}; use std::path::Path; @@ -5,7 +6,6 @@ use async_stream::try_stream; use color_eyre::eyre::eyre; use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; -use imbl::OrdMap; use imbl_value::InternedString; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use models::GatewayId; @@ -13,13 +13,11 @@ use nix::net::if_::if_nametoindex; use tokio::net::{TcpListener, TcpStream}; use tokio::process::Command; -use crate::db::model::public::{NetworkInterfaceInfo, NetworkInterfaceType}; +use crate::db::model::public::{IpInfo, NetworkInterfaceType}; use crate::prelude::*; -use crate::util::collections::OrdMapIterMut; use crate::util::Invoke; -pub async fn load_network_interface_info() -> Result, Error> -{ +pub async fn load_ip_info() -> Result, Error> { let output = String::from_utf8( Command::new("ip") .arg("-o") @@ -36,14 +34,14 @@ pub async fn load_network_interface_info() -> Result::new(); + let mut res = BTreeMap::::new(); for line in output.lines() { let split = line.split_ascii_whitespace().collect::>(); let iface = GatewayId::from(InternedString::from(*split.get(1).ok_or_else(&err_fn)?)); let subnet: IpNet = split.get(3).ok_or_else(&err_fn)?.parse()?; - let info = res.entry(iface).or_default(); - let ip_info = info.ip_info.get_or_insert_default(); + let ip_info = res.entry(iface.clone()).or_default(); + ip_info.name = iface.into(); ip_info.scope_id = split .get(0) .ok_or_else(&err_fn)? @@ -53,8 +51,7 @@ pub async fn load_network_interface_info() -> Result() -> ParentHandler { } let mut table = Table::new(); - table.add_row(row![bc => "FROM", "TO", "GATEWAYS", "CONNECT SSL", "ACTIVE"]); + table.add_row(row![bc => "FROM", "TO", "ACTIVE"]); for (external, targets) in res { for (host, targets) in targets { @@ -74,9 +69,7 @@ pub fn vhost_api() -> ParentHandler { host.as_ref().map(|s| &**s).unwrap_or("*"), external.0 ), - target.addr, - target.gateways.iter().join(", "), - target.connect_ssl.is_ok(), + target, idx == 0 ]); } @@ -101,11 +94,15 @@ pub struct VHostController { servers: SyncMutex>>, } impl VHostController { - pub fn new(db: TypedPatchDb, interfaces: Arc) -> Self { + pub fn new( + db: TypedPatchDb, + interfaces: Arc, + crypto_provider: Arc, + ) -> Self { Self { db, interfaces, - crypto_provider: Arc::new(tokio_rustls::rustls::crypto::ring::default_provider()), + crypto_provider, acme_cache: Arc::new(SyncMutex::new(BTreeMap::new())), servers: SyncMutex::new(BTreeMap::new()), } @@ -115,7 +112,7 @@ impl VHostController { &self, hostname: Option, external: u16, - target: impl VHostTarget, + target: DynVHostTarget, ) -> Result, Error> { self.servers.mutate(|writable| { let server = if let Some(server) = writable.remove(&external) { @@ -136,27 +133,26 @@ impl VHostController { pub fn dump_table( &self, - ) -> BTreeMap, BTreeMap>, EqSet>> - { + ) -> BTreeMap, BTreeMap>, EqSet>> { let ip_info = self.interfaces.watcher.ip_info(); self.servers.peek(|s| { s.iter() .map(|(k, v)| { ( JsonKey::new(*k), - v.mapping - .borrow() - .iter() - .map(|(k, v)| { - ( - JsonKey::new(k.clone()), - v.iter() - .filter(|(_, v)| v.strong_count() > 0) - .map(|(k, _)| ShowTargetInfo::new(k.clone(), &ip_info)) - .collect(), - ) - }) - .collect(), + v.mapping.peek(|m| { + m.iter() + .map(|(k, v)| { + ( + JsonKey::new(k.clone()), + v.iter() + .filter(|(_, v)| v.strong_count() > 0) + .map(|(k, _)| format!("{k:?}")) + .collect(), + ) + }) + .collect() + }), ) }) .collect() @@ -179,8 +175,11 @@ impl VHostController { pub trait VHostTarget: std::fmt::Debug + Eq { type PreprocessRes: Send + 'static; #[allow(unused_variables)] - fn skip(&self, metadata: &::Metadata) -> bool { - false + fn filter(&self, metadata: &::Metadata) -> bool { + true + } + fn acme(&self) -> Option<&AcmeProvider> { + None } fn preprocess<'a>( &'a self, @@ -192,7 +191,8 @@ pub trait VHostTarget: std::fmt::Debug + Eq { } pub trait DynVHostTargetT: std::fmt::Debug + Any { - fn skip(&self, metadata: &::Metadata) -> bool; + fn filter(&self, metadata: &::Metadata) -> bool; + fn acme(&self) -> Option<&AcmeProvider>; fn preprocess<'a>( &'a self, prev: ServerConfig, @@ -203,8 +203,11 @@ pub trait DynVHostTargetT: std::fmt::Debug + Any { fn eq(&self, other: &dyn DynVHostTargetT) -> bool; } impl + 'static> DynVHostTargetT for T { - fn skip(&self, metadata: &::Metadata) -> bool { - VHostTarget::skip(self, metadata) + fn filter(&self, metadata: &::Metadata) -> bool { + VHostTarget::filter(self, metadata) + } + fn acme(&self) -> Option<&AcmeProvider> { + VHostTarget::acme(self) } fn preprocess<'a>( &'a self, @@ -226,12 +229,22 @@ impl + 'static> DynVHostTargetT for T { } } -struct DynVHostTarget(Arc + Send + Sync>); +pub struct DynVHostTarget(Arc + Send + Sync>); +impl DynVHostTarget { + pub fn new + Send + Sync + 'static>(target: T) -> Self { + Self(Arc::new(target)) + } +} impl Clone for DynVHostTarget { fn clone(&self) -> Self { Self(self.0.clone()) } } +impl std::fmt::Debug for DynVHostTarget { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} impl PartialEq for DynVHostTarget { fn eq(&self, other: &Self) -> bool { self.0.eq(&*other.0) @@ -259,6 +272,7 @@ impl Preprocessed { #[derive(Debug, Clone)] pub struct ProxyTarget { pub filter: DynInterfaceFilter, + pub acme: Option, pub addr: SocketAddr, pub connect_ssl: Result, AlpnInfo>, // Ok: yes, connect using ssl, pass through alpn; Err: connect tcp, use provided strategy for alpn } @@ -266,7 +280,8 @@ impl PartialEq for ProxyTarget { fn eq(&self, other: &Self) -> bool { self.filter == other.filter && self.addr == other.addr - && self.connect_ssl.as_ref().err() == other.connect_ssl.as_ref().err() + && self.connect_ssl.as_ref().map(Arc::as_ptr) + == other.connect_ssl.as_ref().map(Arc::as_ptr) } } impl Eq for ProxyTarget {} @@ -274,11 +289,16 @@ impl Eq for ProxyTarget {} impl VHostTarget for ProxyTarget where A: Accept + 'static, - ::Metadata: Send + Sync, + ::Metadata: Visit> + Clone + Send + Sync, { type PreprocessRes = AcceptStream; - fn skip(&self, metadata: &::Metadata) -> bool { - let info = extract::(metadata) + fn filter(&self, metadata: &::Metadata) -> bool { + let info = extract::(metadata); + info.as_ref() + .map_or(true, |i| self.filter.filter(&i.id, &i.info)) + } + fn acme(&self) -> Option<&AcmeProvider> { + self.acme.as_ref() } async fn preprocess<'a>( &'a self, @@ -345,15 +365,61 @@ impl Default for AlpnInfo { type Mapping = BTreeMap, InOMap, Weak<()>>>; -pub struct VHostConnector<'a, A: Accept + 'static>(&'a Watch>, Option>); +pub struct GetVHostAcmeProvider(pub Watch>); +impl Clone for GetVHostAcmeProvider { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} +impl GetAcmeProvider for GetVHostAcmeProvider { + async fn get_provider<'a, 'b: 'a>( + &'b self, + san_info: &'a BTreeSet, + ) -> Option + Send + 'b> { + self.0.peek(|m| -> Option { + san_info + .iter() + .fold(Some::>(None), |acc, x| { + let acc = acc?; + if x.parse::().is_ok() { + return Some(acc); + } + let (t, _) = m + .get(&Some(x.clone()))? + .iter() + .find(|(_, rc)| rc.strong_count() > 0)?; + let acme = t.0.acme()?; + Some(if let Some(acc) = acc { + if acme == acc { + // all must match + Some(acme) + } else { + None + } + } else { + Some(acme) + }) + }) + .flatten() + .cloned() + }) + } +} -impl<'b, A> WrapTlsHandler for &'b mut VHostConnector<'b, A> +pub struct VHostConnector(Watch>, Option>); +impl Clone for VHostConnector { + fn clone(&self) -> Self { + Self(self.0.clone(), None) + } +} + +impl WrapTlsHandler for VHostConnector where A: Accept + 'static, - ::Metadata: Send + Sync, + ::Metadata: Visit> + Send + Sync, { async fn wrap<'a>( - self, + &'a mut self, prev: ServerConfig, hello: &'a ClientHello<'a>, metadata: &'a ::Metadata, @@ -375,7 +441,7 @@ where .into_iter() .flatten() .filter(|(_, rc)| rc.strong_count() > 0) - .find(|(t, _)| !t.skip(metadata)) + .find(|(t, _)| t.0.filter(metadata)) .map(|(e, _)| e.clone()) })?; @@ -387,6 +453,97 @@ where } } +struct VHostListener( + TlsListener< + A, + TlsHandlerWrapper< + ChainedHandler>>, RootCaTlsHandler>, + VHostConnector, + >, + >, +) +where + for<'a> M: HasModel> + + DbAccessMut + + DbAccessMut + + DbAccessByKey = &'a AcmeProvider> + + Send + + Sync, + A: Accept + 'static, + ::Metadata: Visit> + + Visit> + + Clone + + Send + + Sync + + 'static; +struct VHostListenerMetadata { + inner: TlsMetadata, + preprocessed: Preprocessed, +} +impl Accept for VHostListener +where + for<'a> M: HasModel> + + DbAccessMut + + DbAccessMut + + DbAccessByKey = &'a AcmeProvider> + + Send + + Sync + + 'static, + A: Accept + 'static, + ::Metadata: Visit> + + Visit> + + Clone + + Send + + Sync + + 'static, +{ + type Metadata = VHostListenerMetadata; + fn poll_accept( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let (metadata, stream) = ready!(self.0.poll_accept(cx)?); + let preprocessed = self.0.tls_handler.wrapper.1.take(); + Poll::Ready(Ok(( + VHostListenerMetadata { + inner: metadata, + preprocessed: preprocessed.ok_or_else(|| { + Error::new( + eyre!("tlslistener yielded but preprocessed isn't set"), + ErrorKind::Incoherent, + ) + })?, + }, + stream, + ))) + } +} +impl VHostListener +where + for<'a> M: HasModel> + + DbAccessMut + + DbAccessMut + + DbAccessByKey = &'a AcmeProvider> + + Send + + Sync + + 'static, + A: Accept + 'static, + ::Metadata: Visit> + + Visit> + + Clone + + Send + + Sync + + 'static, +{ + async fn handle_next(&mut self) -> Result<(), Error> { + let (metadata, stream) = futures::future::poll_fn(|cx| self.poll_accept(cx)).await?; + + metadata.preprocessed.finish(stream); + + Ok(()) + } +} + struct VHostServer { mapping: Watch>, _thread: NonDetachingJoinHandle<()>, @@ -407,472 +564,69 @@ impl<'a> From<&'a BTreeMap, BTreeMap, - db: TypedPatchDb, - acme_tls_alpn_cache: AcmeTlsAlpnCache, - crypto_provider: Arc, - ) -> Result<(), Error> { - let accepted; - - loop { - let any_filter = AnyFilter::from(&*mapping.borrow()); - - let changed_filter = mapping - .wait_for(|m| any_filter != AnyFilter::from(m)) - .boxed(); - - tokio::select! { - a = listener.accept(&any_filter) => { - accepted = a?; - break; - } - _ = changed_filter => { - tracing::debug!("port {} filter changed", listener.port()); - } - } - } - - let check = listener.check_filter(); - tokio::spawn(async move { - let bind = accepted.bind; - if let Err(e) = Self::handle_stream( - accepted, - check, - mapping, - db, - acme_tls_alpn_cache, - crypto_provider, - ) - .await - { - tracing::error!("Error in VHostController on {bind}: {e}"); - tracing::debug!("{e:?}") - } - }); - Ok(()) - } - - async fn handle_stream( - Accepted { - stream, - wan_ip, - bind, - .. - }: Accepted, - check_filter: impl FnOnce(SocketAddr, &DynInterfaceFilter) -> bool, - mapping: watch::Receiver, - db: TypedPatchDb, - acme_tls_alpn_cache: AcmeTlsAlpnCache, - crypto_provider: Arc, - ) -> Result<(), Error> { - let mut stream = BackTrackingIO::new(stream); - let mid: tokio_rustls::StartHandshake<&mut BackTrackingIO> = - match LazyConfigAcceptor::new(Acceptor::default(), &mut stream).await { - Ok(a) => a, - Err(e) => { - let (_, buf) = stream.rewind(); - if std::str::from_utf8(buf) - .ok() - .and_then(|buf| { - buf.lines() - .map(|l| l.trim()) - .filter(|l| !l.is_empty()) - .next() - }) - .map_or(false, |buf| { - regex::Regex::new("[A-Z]+ (.+) HTTP/1") - .unwrap() - .is_match(buf) - }) - { - return hyper_util::server::conn::auto::Builder::new( - hyper_util::rt::TokioExecutor::new(), - ) - .serve_connection( - hyper_util::rt::TokioIo::new(stream), - hyper_util::service::TowerToHyperService::new( - axum::Router::new().fallback(axum::routing::method_routing::any( - move |req: Request| async move { - match async move { - let host = req - .headers() - .get(http::header::HOST) - .and_then(|host| host.to_str().ok()); - if let Some(host) = host { - let uri = Uri::from_parts({ - let mut parts = - req.uri().to_owned().into_parts(); - parts.scheme = Some("https".parse()?); - parts.authority = Some(host.parse()?); - parts - })?; - Response::builder() - .status(http::StatusCode::TEMPORARY_REDIRECT) - .header(http::header::LOCATION, uri.to_string()) - .body(Body::default()) - } else { - Response::builder() - .status(http::StatusCode::BAD_REQUEST) - .body(Body::from("Host header required")) - } - } - .await - { - Ok(a) => a, - Err(e) => { - tracing::warn!( - "Error redirecting http request on ssl port: {e}" - ); - tracing::error!("{e:?}"); - server_error(Error::new(e, ErrorKind::Network)) - } - } - }, - )), - ), - ) - .await - .map_err(|e| { - Error::new(color_eyre::eyre::Report::msg(e), ErrorKind::Network) - }); - } else { - return Err(e).with_kind(ErrorKind::Network); - } - } - }; - let target_name: Option = - mid.client_hello().server_name().map(|s| s.into()); - if let Some(domain) = target_name.as_ref() { - if mid - .client_hello() - .alpn() - .into_iter() - .flatten() - .any(|alpn| alpn == ACME_TLS_ALPN_NAME) - { - let cert = WatchStream::new( - acme_tls_alpn_cache - .peek(|c| c.get(&**domain).cloned()) - .ok_or_else(|| { - Error::new( - eyre!("No challenge recv available for {domain}"), - ErrorKind::OpenSsl, - ) - })?, - ); - tracing::info!("Waiting for verification cert for {domain}"); - let cert = cert - .filter(|c| c.is_some()) - .next() - .await - .flatten() - .ok_or_else(|| { - Error::new( - eyre!("No challenge available for {domain}"), - ErrorKind::OpenSsl, - ) - })?; - tracing::info!("Verification cert received for {domain}"); - let mut cfg = ServerConfig::builder_with_provider(crypto_provider.clone()) - .with_safe_default_protocol_versions() - .with_kind(crate::ErrorKind::OpenSsl)? - .with_no_client_auth() - .with_cert_resolver(Arc::new(SingleCertResolver(cert))); - - cfg.alpn_protocols = vec![ACME_TLS_ALPN_NAME.to_vec()]; - tracing::info!("performing ACME auth challenge"); - let mut accept = mid.into_stream(Arc::new(cfg)); - let io = accept.get_mut().unwrap(); - let buffered = io.stop_buffering(); - io.write_all(&buffered).await?; - accept.await?; - tracing::info!("ACME auth challenge completed"); - return Ok(()); - } - } - let target = { - let m = mapping.borrow(); - m.get(&target_name) - .into_iter() - .flatten() - .find(|(_, rc)| rc.strong_count() > 0) - .or_else(|| { - if target_name - .as_ref() - .map(|s| s.parse::().is_ok()) - .unwrap_or(true) - { - m.get(&None) - .into_iter() - .flatten() - .find(|(_, rc)| rc.strong_count() > 0) - } else { - None - } - }) - .map(|(target, _)| target.clone()) - }; - if let Some(target) = target { - if !check_filter(bind, &target.filter) { - log::warn!("Connection from {bind} to {target:?} rejected by filter"); - return Ok(()); - } - let peek = db.peek().await; - let root = peek - .as_private() - .as_key_store() - .as_local_certs() - .as_root_cert() - .de()?; - let mut cfg = async { - if let Some((domain, provider, settings)) = - target_name.as_ref().and_then(|domain| { - target.acme.as_ref().and_then(|a| { - peek.as_public() - .as_server_info() - .as_network() - .as_acme() - .as_idx(a) - .map(|s| (domain, a, s)) - }) - }) - { - let acme_settings = settings.de()?; - let mut identifiers = vec![Identifier::Dns(domain.to_string())]; - if false - // Requires RFC 8738 - { - if let Some(wan_ip) = wan_ip { - identifiers.push(Identifier::Ip(wan_ip.into())); - } - } - let (send, recv) = watch::channel(None); - acme_tls_alpn_cache.mutate(|c| c.insert(domain.clone(), recv)); - let cert = async_acme::rustls_helper::order( - |_, cert| { - send.send_replace(Some(Arc::new(cert))); - Ok(()) - }, - provider.0.as_str(), - &identifiers, - Some(&AcmeCertCache(&db)), - &acme_settings.contact, - ) - .await - .with_kind(ErrorKind::OpenSsl)?; - return Ok(ServerConfig::builder_with_provider(crypto_provider.clone()) - .with_safe_default_protocol_versions() - .with_kind(crate::ErrorKind::OpenSsl)? - .with_no_client_auth() - .with_cert_resolver(Arc::new(SingleCertResolver(Arc::new(cert))))); - } - - let hostnames = target_name - .into_iter() - .chain([InternedString::from_display(&bind.ip())]) - .chain(wan_ip.as_ref().map(InternedString::from_display)) - .collect(); - let key = db - .mutate(|v| { - v.as_private_mut() - .as_key_store_mut() - .as_local_certs_mut() - .cert_for(&hostnames) - }) - .await - .result?; - let cfg = ServerConfig::builder_with_provider(crypto_provider.clone()) - .with_safe_default_protocol_versions() - .with_kind(crate::ErrorKind::OpenSsl)? - .with_no_client_auth(); - if mid - .client_hello() - .signature_schemes() - .contains(&tokio_rustls::rustls::SignatureScheme::ED25519) - { - cfg.with_single_cert( - key.fullchain_ed25519() - .into_iter() - .map(|c| { - Ok(tokio_rustls::rustls::pki_types::CertificateDer::from( - c.to_der()?, - )) - }) - .collect::>()?, - PrivateKeyDer::from(PrivatePkcs8KeyDer::from( - key.leaf.keys.ed25519.private_key_to_pkcs8()?, - )), - ) - } else { - cfg.with_single_cert( - key.fullchain_nistp256() - .into_iter() - .map(|c| { - Ok(tokio_rustls::rustls::pki_types::CertificateDer::from( - c.to_der()?, - )) - }) - .collect::>()?, - PrivateKeyDer::from(PrivatePkcs8KeyDer::from( - key.leaf.keys.nistp256.private_key_to_pkcs8()?, - )), - ) - } - .with_kind(crate::ErrorKind::OpenSsl) - } - .await?; - let mut tcp_stream = TcpStream::connect(target.addr).await?; - match target.connect_ssl { - Ok(()) => { - let mut client_cfg = - tokio_rustls::rustls::ClientConfig::builder_with_provider(crypto_provider) - .with_safe_default_protocol_versions() - .with_kind(crate::ErrorKind::OpenSsl)? - .with_root_certificates({ - let mut store = RootCertStore::empty(); - store - .add(CertificateDer::from(root.to_der()?)) - .with_kind(crate::ErrorKind::OpenSsl)?; - store - }) - .with_no_client_auth(); - client_cfg.alpn_protocols = mid - .client_hello() - .alpn() - .into_iter() - .flatten() - .map(|x| x.to_vec()) - .collect(); - let mut target_stream = TlsConnector::from(Arc::new(client_cfg)) - .connect_with( - ServerName::IpAddress(target.addr.ip().into()), - tcp_stream, - |conn| { - cfg.alpn_protocols - .extend(conn.alpn_protocol().into_iter().map(|p| p.to_vec())) - }, - ) - .await - .with_kind(crate::ErrorKind::OpenSsl)?; - let mut accept = mid.into_stream(Arc::new(cfg)); - let io = accept.get_mut().unwrap(); - let buffered = io.stop_buffering(); - io.write_all(&buffered).await?; - let mut tls_stream = match accept.await { - Ok(a) => a, - Err(e) => { - tracing::trace!( - "VHostController: failed to accept TLS connection on {bind}: {e}" - ); - tracing::trace!("{e:?}"); - return Ok(()); - } - }; - tokio::io::copy_bidirectional(&mut tls_stream, &mut target_stream).await - } - Err(AlpnInfo::Reflect) => { - for proto in mid.client_hello().alpn().into_iter().flatten() { - cfg.alpn_protocols.push(proto.into()); - } - let mut accept = mid.into_stream(Arc::new(cfg)); - let io = accept.get_mut().unwrap(); - let buffered = io.stop_buffering(); - io.write_all(&buffered).await?; - let mut tls_stream = match accept.await { - Ok(a) => a, - Err(e) => { - tracing::trace!( - "VHostController: failed to accept TLS connection on {bind}: {e}" - ); - tracing::trace!("{e:?}"); - return Ok(()); - } - }; - tokio::io::copy_bidirectional(&mut tls_stream, &mut tcp_stream).await - } - Err(AlpnInfo::Specified(alpn)) => { - cfg.alpn_protocols = alpn.into_iter().map(|a| a.0).collect(); - let mut accept = mid.into_stream(Arc::new(cfg)); - let io = accept.get_mut().unwrap(); - let buffered = io.stop_buffering(); - io.write_all(&buffered).await?; - let mut tls_stream = match accept.await { - Ok(a) => a, - Err(e) => { - tracing::trace!( - "VHostController: failed to accept TLS connection on {bind}: {e}" - ); - tracing::trace!("{e:?}"); - return Ok(()); - } - }; - tokio::io::copy_bidirectional(&mut tls_stream, &mut tcp_stream).await - } - } - .map_or_else( - |e| { - use std::io::ErrorKind as E; - match e.kind() { - E::UnexpectedEof - | E::BrokenPipe - | E::ConnectionAborted - | E::ConnectionReset - | E::ConnectionRefused - | E::TimedOut - | E::Interrupted - | E::NotConnected => Ok(()), - _ => Err(e), - } - }, - |_| Ok(()), - )?; - } else { - // 503 - } - Ok::<_, Error>(()) - } - +impl VHostServer { #[instrument(skip_all)] - fn new( + fn new( listener: A, - db: TypedPatchDb, + db: TypedPatchDb, crypto_provider: Arc, acme_cache: AcmeTlsAlpnCache, - ) -> Result { + ) -> Self + where + for<'a> M: HasModel> + + DbAccessMut + + DbAccessMut + + DbAccessByKey = &'a AcmeProvider> + + Send + + Sync + + 'static, + A: Accept + Send + 'static, + ::Metadata: Visit> + + Visit> + + Clone + + Send + + Sync + + 'static, + { let mapping = Watch::new(BTreeMap::new()); - Ok(Self { + Self { mapping: mapping.clone(), _thread: tokio::spawn(async move { - let listener = TlsListener::new( + let mut listener = VHostListener(TlsListener::new( listener, - VHostTlsHandler { - cert_handler: ChainedHandler( - &AcmeTlsHandler { - db: &db, - acme_cache: &acme_cache, - crypto_provider: &crypto_provider, - get_provider: todo!(), + TlsHandlerWrapper { + inner: ChainedHandler( + Arc::new(AcmeTlsHandler { + db: db.clone(), + acme_cache, + crypto_provider: crypto_provider.clone(), + get_provider: GetVHostAcmeProvider(mapping.clone()), in_progress: Watch::new(BTreeSet::new()), + }), + RootCaTlsHandler { + db, + crypto_provider, }, - todo!(), ), - alpn_handler: todo!(), + wrapper: VHostConnector(mapping, None), }, - ); + )); loop { - if let Err(e) = Self::accept(&mut listener, &mapping).await { - tracing::error!("VHostController: failed to accept connection: {e}"); + if let Err(e) = listener.handle_next().await { + tracing::error!("VHostServer: failed to accept connection: {e}"); tracing::debug!("{e:?}"); } } }) .into(), - }) + } } - fn add(&self, hostname: Option, target: ProxyTarget) -> Result, Error> { + fn add( + &self, + hostname: Option, + target: DynVHostTarget, + ) -> Result, Error> { + let target = target.into(); let mut res = Ok(Arc::new(())); self.mapping.send_if_modified(|writable| { let mut changed = false; @@ -888,7 +642,7 @@ impl VHostServer { res = Ok(rc); changed }); - if !self.mapping.is_closed() { + if self.mapping.watcher_count() > 1 { res } else { Err(Error::new( @@ -913,6 +667,6 @@ impl VHostServer { }); } fn is_empty(&self) -> bool { - self.mapping.borrow().is_empty() + self.mapping.peek(|m| m.is_empty()) } } diff --git a/core/startos/src/net/web_server.rs b/core/startos/src/net/web_server.rs index 5a5ef465c..13f67a436 100644 --- a/core/startos/src/net/web_server.rs +++ b/core/startos/src/net/web_server.rs @@ -1,21 +1,22 @@ use std::any::Any; +use std::collections::BTreeMap; use std::future::Future; use std::net::SocketAddr; use std::ops::Deref; use std::pin::Pin; use std::sync::Arc; -use std::task::Poll; +use std::task::{ready, Poll}; use std::time::Duration; use axum::Router; use futures::future::Either; -use futures::FutureExt; +use futures::{FutureExt, TryFutureExt}; use helpers::NonDetachingJoinHandle; use http::Extensions; use hyper_util::rt::{TokioIo, TokioTimer}; use tokio::net::TcpListener; use tokio::sync::oneshot; -use visit_rs::{Visit, Visitor}; +use visit_rs::{Visit, VisitFields, Visitor}; use crate::net::static_server::{ui_router, UiContext}; use crate::prelude::*; @@ -38,6 +39,16 @@ impl<'a> MetadataVisitor for ExtensionVisitor<'a> { self.0.insert(metadata.clone()); } } +impl<'a> Visit> + for Box Visit> + Send + Sync + 'static> +{ + fn visit( + &self, + visitor: &mut ExtensionVisitor<'a>, + ) -> as Visitor>::Result { + (&**self).visit(visitor) + } +} pub struct ExtractVisitor(Option); impl Visitor for ExtractVisitor { @@ -50,7 +61,12 @@ impl MetadataVisitor for ExtractVisitor { } } } -pub fn extract>>(metadata: &M) -> Option { +pub fn extract< + T: Clone + Send + Sync + 'static, + M: Visit> + Clone + Send + Sync + 'static, +>( + metadata: &M, +) -> Option { let mut visitor = ExtractVisitor(None); visitor.visit(metadata); visitor.0 @@ -73,6 +89,13 @@ pub trait Accept { &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>; + fn into_dyn(self) -> DynAccept + where + Self: Sized + Send + Sync + 'static, + for<'a> Self::Metadata: Visit> + Send + Sync + 'static, + { + DynAccept::new(self) + } } impl Accept for TcpListener { @@ -121,6 +144,47 @@ where } } +#[derive(Clone, VisitFields)] +pub struct MapListenerMetadata { + pub inner: M, + pub key: K, +} +impl Visit for MapListenerMetadata +where + V: MetadataVisitor, + K: Visit, + M: Visit, +{ + fn visit(&self, visitor: &mut V) -> ::Result { + self.visit_fields(visitor).collect() + } +} + +impl Accept for BTreeMap +where + K: Clone, + A: Accept, +{ + type Metadata = MapListenerMetadata; + fn poll_accept( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + for (key, listener) in self { + if let Poll::Ready((metadata, stream)) = listener.poll_accept(cx)? { + return Poll::Ready(Ok(( + MapListenerMetadata { + inner: metadata, + key: key.clone(), + }, + stream, + ))); + } + } + Poll::Pending + } +} + impl Accept for Either where A: Accept, @@ -150,6 +214,68 @@ impl Accept for Option { } } +trait DynAcceptT: Send + Sync { + fn poll_accept( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll< + Result< + ( + Box Visit> + Send + Sync>, + AcceptStream, + ), + Error, + >, + >; +} +impl DynAcceptT for A +where + A: Accept + Send + Sync, + for<'a> ::Metadata: Visit> + Send + Sync + 'static, +{ + fn poll_accept( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll< + Result< + ( + Box Visit> + Send + Sync>, + AcceptStream, + ), + Error, + >, + > { + let (metadata, stream) = ready!(Accept::poll_accept(self, cx)?); + Poll::Ready(Ok((Box::new(metadata), stream))) + } +} +pub struct DynAccept(Box); +impl Accept for DynAccept { + type Metadata = Box Visit> + Send + Sync>; + fn poll_accept( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + DynAcceptT::poll_accept(&mut *self.0, cx) + } + fn into_dyn(self) -> DynAccept + where + Self: Sized, + for<'a> Self::Metadata: Visit> + Send + Sync + 'static, + { + self + } +} +impl DynAccept { + pub fn new(accept: A) -> Self + where + A: Accept + Send + Sync + 'static, + for<'a> ::Metadata: Visit> + Send + Sync + 'static, + { + Self(Box::new(accept)) + } +} + #[pin_project::pin_project] pub struct Acceptor { acceptor: Watch, @@ -184,6 +310,64 @@ impl Acceptor> { )) } } +impl Acceptor> { + pub async fn bind_dyn(listen: impl IntoIterator) -> Result { + Ok(Self::new( + futures::future::try_join_all( + listen + .into_iter() + .map(TcpListener::bind) + .map(|f| f.map_ok(DynAccept::new)), + ) + .await?, + )) + } +} +impl Acceptor> +where + K: Ord + Clone + Send + Sync + 'static, +{ + pub async fn bind_map( + listen: impl IntoIterator, + ) -> Result { + Ok(Self::new( + futures::future::try_join_all(listen.into_iter().map(|(key, addr)| async move { + Ok::<_, Error>(( + key, + TcpListener::bind(addr) + .await + .with_kind(ErrorKind::Network)?, + )) + })) + .await? + .into_iter() + .collect(), + )) + } +} +impl Acceptor> +where + K: Ord + Clone + Send + Sync + 'static, +{ + pub async fn bind_map_dyn( + listen: impl IntoIterator, + ) -> Result { + Ok(Self::new( + futures::future::try_join_all(listen.into_iter().map(|(key, addr)| async move { + Ok::<_, Error>(( + key, + TcpListener::bind(addr) + .await + .with_kind(ErrorKind::Network)?, + )) + })) + .await? + .into_iter() + .map(|(key, listener)| (key, listener.into_dyn())) + .collect(), + )) + } +} pub struct WebServerAcceptorSetter { acceptor: Watch, diff --git a/core/startos/src/registry/mod.rs b/core/startos/src/registry/mod.rs index 811188d4e..e9aacd168 100644 --- a/core/startos/src/registry/mod.rs +++ b/core/startos/src/registry/mod.rs @@ -3,7 +3,7 @@ use std::collections::{BTreeMap, BTreeSet}; use axum::Router; use futures::future::ready; use models::DataUrl; -use rpc_toolkit::{Context, HandlerExt, ParentHandler, Server, from_fn_async}; +use rpc_toolkit::{from_fn_async, Context, HandlerExt, ParentHandler, Server}; use serde::{Deserialize, Serialize}; use ts_rs::TS; @@ -141,9 +141,3 @@ pub fn registry_router(ctx: RegistryContext) -> Router { }), ) } - -impl WebServer { - pub fn serve_registry(&mut self, ctx: RegistryContext) { - self.serve_router(registry_router(ctx)) - } -} diff --git a/core/startos/src/service/effects/net/ssl.rs b/core/startos/src/service/effects/net/ssl.rs index be8a5d7d7..a39c6a35a 100644 --- a/core/startos/src/service/effects/net/ssl.rs +++ b/core/startos/src/service/effects/net/ssl.rs @@ -6,11 +6,11 @@ use ipnet::IpNet; use itertools::Itertools; use openssl::pkey::{PKey, Private}; -use crate::HOST_IP; use crate::service::effects::callbacks::CallbackHandler; use crate::service::effects::prelude::*; use crate::service::rpc::CallbackId; use crate::util::serde::Pem; +use crate::HOST_IP; #[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, TS, PartialEq, Eq)] #[serde(rename_all = "camelCase")] @@ -95,7 +95,7 @@ pub async fn get_ssl_certificate( .as_entries()? .into_iter() .flat_map(|(_, net)| net.as_ip_info().transpose_ref()) - .flat_map(|net| net.as_subnets().de().log_err()) + .flat_map(|net| net.as_deref().as_subnets().de().log_err()) .flatten() .any(|s| s.addr() == ip) { diff --git a/core/startos/src/tunnel/context.rs b/core/startos/src/tunnel/context.rs index c03e6d9f9..dbe270c10 100644 --- a/core/startos/src/tunnel/context.rs +++ b/core/startos/src/tunnel/context.rs @@ -1,12 +1,11 @@ -use std::collections::{BTreeMap, BTreeSet}; -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::collections::BTreeMap; +use std::net::{IpAddr, SocketAddr}; use std::ops::Deref; use std::path::{Path, PathBuf}; use std::sync::Arc; use clap::Parser; use cookie::{Cookie, Expiration, SameSite}; -use helpers::NonDetachingJoinHandle; use http::HeaderMap; use imbl::OrdMap; use imbl_value::InternedString; @@ -23,14 +22,16 @@ use url::Url; use crate::auth::Sessions; use crate::context::config::ContextConfig; use crate::context::{CliContext, RpcContext}; -use crate::db::model::public::{NetworkInterfaceInfo, NetworkInterfaceType}; +use crate::db::model::public::NetworkInterfaceInfo; use crate::middleware::auth::AuthContext; use crate::net::forward::PortForwardController; -use crate::net::gateway::{IdFilter, InterfaceFilter, NetworkInterfaceWatcher}; +use crate::net::gateway::{IdFilter, InterfaceFilter}; use crate::prelude::*; use crate::rpc_continuations::{OpenAuthedContinuations, RpcContinuations}; use crate::tunnel::db::{GatewayPort, TunnelDatabase}; +use crate::tunnel::wg::WIREGUARD_INTERFACE_NAME; use crate::tunnel::TUNNEL_DEFAULT_LISTEN; +use crate::util::collections::OrdMapIterMut; use crate::util::io::read_file_to_string; use crate::util::sync::{SyncMutex, Watch}; use crate::util::Invoke; @@ -100,7 +101,24 @@ impl TunnelContext { ) .await?; let listen = config.tunnel_listen.unwrap_or(TUNNEL_DEFAULT_LISTEN); - let net_iface = Watch::new(crate::net::utils::load_network_interface_info().await?); + let ip_info = crate::net::utils::load_ip_info().await?; + let net_iface = db + .mutate(|db| { + db.as_gateways_mut().mutate(|g| { + for (_, v) in OrdMapIterMut::from(&mut *g) { + v.ip_info = None; + } + for (id, info) in ip_info { + if id.as_str() != WIREGUARD_INTERFACE_NAME { + g.entry(id).or_default().ip_info = Some(Arc::new(info)); + } + } + Ok(g.clone()) + }) + }) + .await + .result?; + let net_iface = Watch::new(net_iface); let forward = PortForwardController::new(net_iface.clone_unseen()); Command::new("sysctl") @@ -111,12 +129,8 @@ impl TunnelContext { for iface in net_iface.peek(|i| { i.iter() - .filter(|(_, info)| { - dbg!(info).ip_info.as_ref().map_or(false, |i| { - dbg!(i).device_type != Some(NetworkInterfaceType::Wireguard) - }) - }) .map(|(name, _)| name) + .filter(|id| id.as_str() != WIREGUARD_INTERFACE_NAME) .cloned() .collect::>() }) { diff --git a/core/startos/src/tunnel/db.rs b/core/startos/src/tunnel/db.rs index 1de5a26b5..ec7cba6e0 100644 --- a/core/startos/src/tunnel/db.rs +++ b/core/startos/src/tunnel/db.rs @@ -1,15 +1,13 @@ -use std::collections::{BTreeMap, BTreeSet}; -use std::net::{IpAddr, SocketAddr, SocketAddrV4}; +use std::collections::BTreeMap; +use std::net::{SocketAddr, SocketAddrV4}; use std::path::PathBuf; use clap::builder::ValueParserFactory; use clap::Parser; -use imbl::HashMap; +use imbl::{HashMap, OrdMap}; use imbl_value::InternedString; use itertools::Itertools; use models::{FromStrParser, GatewayId}; -use openssl::pkey::{PKey, Private}; -use openssl::x509::X509; use patch_db::json_ptr::{JsonPointer, ROOT}; use patch_db::Dump; use rpc_toolkit::yajrc::RpcError; @@ -20,16 +18,14 @@ use ts_rs::TS; use crate::auth::Sessions; use crate::context::CliContext; -use crate::net::ssl::FullchainCertData; +use crate::db::model::public::NetworkInterfaceInfo; use crate::prelude::*; use crate::sign::AnyVerifyingKey; use crate::tunnel::auth::SignerInfo; use crate::tunnel::context::TunnelContext; use crate::tunnel::web::TunnelCertData; use crate::tunnel::wg::WgServer; -use crate::util::serde::{ - apply_expr, deserialize_from_str, serialize_display, HandlerExtSerde, Pem, -}; +use crate::util::serde::{apply_expr, deserialize_from_str, serialize_display, HandlerExtSerde}; #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct GatewayPort(pub GatewayId, pub u16); @@ -84,7 +80,8 @@ pub struct TunnelDatabase { pub sessions: Sessions, pub password: Option, pub auth_pubkeys: HashMap, - pub certificates: BTreeMap>, TunnelCertData>, + pub certificates: Option, + pub gateways: OrdMap, pub wg: WgServer, pub port_forwards: PortForwards, } diff --git a/core/startos/src/tunnel/mod.rs b/core/startos/src/tunnel/mod.rs index 99e81a947..9855eff38 100644 --- a/core/startos/src/tunnel/mod.rs +++ b/core/startos/src/tunnel/mod.rs @@ -79,9 +79,3 @@ pub fn tunnel_router(ctx: TunnelContext) -> Router { }), ) } - -impl WebServer { - pub fn serve_tunnel(&mut self, ctx: TunnelContext) { - self.serve_router(tunnel_router(ctx)) - } -} diff --git a/core/startos/src/tunnel/web.rs b/core/startos/src/tunnel/web.rs index a855ecc74..6eb4312af 100644 --- a/core/startos/src/tunnel/web.rs +++ b/core/startos/src/tunnel/web.rs @@ -10,11 +10,16 @@ use openssl::x509::X509; use rpc_toolkit::{from_fn_async, Context, Empty, HandlerArgs, HandlerExt, ParentHandler}; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncBufReadExt, BufReader}; +use tokio_rustls::rustls::server::ClientHello; +use tokio_rustls::rustls::ServerConfig; use crate::context::CliContext; use crate::net::ssl::SANInfo; +use crate::net::tls::TlsHandler; +use crate::net::web_server::Accept; use crate::prelude::*; use crate::tunnel::context::TunnelContext; +use crate::tunnel::db::TunnelDatabase; use crate::util::serde::Pem; #[derive(Debug, Deserialize, Serialize, Parser)] @@ -26,6 +31,19 @@ pub struct TunnelCertData { pub cert: Pem>, } +pub struct TunnelCertHandler(TypedPatchDb); +impl<'a, A> TlsHandler<'a, A> for TunnelCertHandler +where + A: Accept, +{ + async fn get_config( + &'a mut self, + hello: &'a ClientHello<'a>, + metadata: &'a ::Metadata, + ) -> Option { + } +} + pub fn web_api() -> ParentHandler { ParentHandler::new() .subcommand("init", from_fn_async(init_web_rpc).no_cli()) diff --git a/core/startos/src/tunnel/wg.rs b/core/startos/src/tunnel/wg.rs index 848061d05..5b19f508c 100644 --- a/core/startos/src/tunnel/wg.rs +++ b/core/startos/src/tunnel/wg.rs @@ -13,6 +13,8 @@ use crate::util::io::write_file_atomic; use crate::util::serde::Base64; use crate::util::Invoke; +pub const WIREGUARD_INTERFACE_NAME: &str = "wg-start-tunnel"; + #[derive(Deserialize, Serialize, HasModel)] #[serde(rename_all = "camelCase")] #[model = "Model"] @@ -37,7 +39,7 @@ impl WgServer { pub async fn sync(&self) -> Result<(), Error> { Command::new("wg-quick") .arg("down") - .arg("wg0") + .arg(WIREGUARD_INTERFACE_NAME) .invoke(ErrorKind::Network) .await .or_else(|e| { @@ -49,13 +51,13 @@ impl WgServer { } })?; write_file_atomic( - "/etc/wireguard/wg0.conf", + const_format::formatcp!("/etc/wireguard/{WIREGUARD_INTERFACE_NAME}.conf"), self.server_config().to_string().as_bytes(), ) .await?; Command::new("wg-quick") .arg("up") - .arg("wg0") + .arg(WIREGUARD_INTERFACE_NAME) .invoke(ErrorKind::Network) .await?; Ok(()) diff --git a/core/startos/src/util/sync.rs b/core/startos/src/util/sync.rs index e9d0251eb..40f6dcfbf 100644 --- a/core/startos/src/util/sync.rs +++ b/core/startos/src/util/sync.rs @@ -266,6 +266,9 @@ impl Watch { version: 0, } } + pub fn watcher_count(&self) -> usize { + Arc::strong_count(&self.shared) + } #[cfg_attr(feature = "unstable", inline(never))] pub fn poll_changed(&mut self, cx: &mut std::task::Context<'_>) -> Poll<()> { self.shared.mutate(|shared| { diff --git a/core/startos/src/version/v0_4_0_alpha_12.rs b/core/startos/src/version/v0_4_0_alpha_12.rs index 32fce1c2e..7bbbbb6ee 100644 --- a/core/startos/src/version/v0_4_0_alpha_12.rs +++ b/core/startos/src/version/v0_4_0_alpha_12.rs @@ -75,6 +75,8 @@ impl VersionT for Version { } fix_host(&mut db["public"]["serverInfo"]["network"]["host"])?; + db["private"]["keyStore"]["localCerts"] = db["private"]["keyStore"]["local_certs"].clone(); + Ok(Value::Null) } fn down(self, _db: &mut Value) -> Result<(), Error> { diff --git a/patch-db b/patch-db index 5b23a1eac..90b336d6a 160000 --- a/patch-db +++ b/patch-db @@ -1 +1 @@ -Subproject commit 5b23a1eac6439fff930c9b0eafae7680339dcd4b +Subproject commit 90b336d6a98d3c89798a69138786c51427a80e7d diff --git a/sdk/base/lib/osBindings/NetworkInterfaceType.ts b/sdk/base/lib/osBindings/NetworkInterfaceType.ts index e20067dcc..d04c37ca2 100644 --- a/sdk/base/lib/osBindings/NetworkInterfaceType.ts +++ b/sdk/base/lib/osBindings/NetworkInterfaceType.ts @@ -1,3 +1,8 @@ // This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. -export type NetworkInterfaceType = "ethernet" | "wireless" | "wireguard" +export type NetworkInterfaceType = + | "ethernet" + | "wireless" + | "bridge" + | "wireguard" + | "loopback" diff --git a/sdk/base/lib/osBindings/SignerInfo.ts b/sdk/base/lib/osBindings/SignerInfo.ts index 76cbdafce..7e7aa2588 100644 --- a/sdk/base/lib/osBindings/SignerInfo.ts +++ b/sdk/base/lib/osBindings/SignerInfo.ts @@ -1,3 +1,9 @@ // This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. +import type { AnyVerifyingKey } from "./AnyVerifyingKey" +import type { ContactInfo } from "./ContactInfo" -export type SignerInfo = { name: string } +export type SignerInfo = { + name: string + contact: Array + keys: Array +} diff --git a/sdk/base/lib/util/ip.ts b/sdk/base/lib/util/ip.ts index a631b6b23..c7a26102e 100644 --- a/sdk/base/lib/util/ip.ts +++ b/sdk/base/lib/util/ip.ts @@ -80,6 +80,8 @@ export const PRIVATE_IPV4_RANGES = [ new IpNet("192.168.0.0/16"), ] +export const IPV4_LOOPBACK = new IpNet("127.0.0.0/8") +export const IPV6_LOOPBACK = new IpNet("::1/128") export const IPV6_LINK_LOCAL = new IpNet("fe80::/10") export const CGNAT = new IpNet("100.64.0.0/10") diff --git a/web/projects/ui/src/app/services/gateway.service.ts b/web/projects/ui/src/app/services/gateway.service.ts index 5650144d6..3bb61ba07 100644 --- a/web/projects/ui/src/app/services/gateway.service.ts +++ b/web/projects/ui/src/app/services/gateway.service.ts @@ -24,6 +24,11 @@ export class GatewayService { map(gateways => Object.entries(gateways) .filter(([_, val]) => !!val?.ipInfo) + .filter( + ([_, val]) => + val?.ipInfo?.deviceType !== 'bridge' && + val?.ipInfo?.deviceType !== 'loopback', + ) .map(([id, val]) => { const subnets = val.ipInfo?.subnets.map(s => utils.IpNet.parse(s)) ?? []