wip: tls refactor

This commit is contained in:
Aiden McClelland
2025-10-24 09:25:30 -06:00
parent 2056d4def1
commit 82a3a435f5
23 changed files with 1743 additions and 879 deletions

548
core/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,15 @@
use std::ffi::OsString;
use std::time::Duration;
use clap::Parser;
use futures::FutureExt;
use helpers::NonDetachingJoinHandle;
use rpc_toolkit::CliApp;
use tokio::signal::unix::signal;
use tracing::instrument;
use crate::context::CliContext;
use crate::context::config::ClientConfig;
use crate::context::CliContext;
use crate::net::web_server::{Acceptor, WebServer};
use crate::prelude::*;
use crate::tunnel::context::{TunnelConfig, TunnelContext};
@@ -17,13 +19,51 @@ use crate::util::logger::LOGGER;
async fn inner_main(config: &TunnelConfig) -> Result<(), Error> {
let server = async {
let ctx = TunnelContext::init(config).await?;
let mut server = WebServer::new(Acceptor::bind([ctx.listen]).await?);
let listen = ctx.listen;
let mut server = WebServer::new(Acceptor::bind([listen]).await?);
let https_thread: NonDetachingJoinHandle<()> = tokio::spawn(async move {
let mut sub = setter_db.subscribe("/webserver".parse().unwrap()).await;
while sub.recv().await.is_some() {
while let Err(e) = async {
let external = setter_db.peek().await.into_webserver().de()?;
let mut bind_err = None;
setter.send_modify(|a| {
a.retain(|a, _| *a == listen || Some(*a) == external);
if let Some(external) = external {
if !a.contains_key(&external) {
match mio::net::TcpListener::bind(external) {
Ok(l) => {
a.insert(external, TcpListener::from_std(l.into()));
}
Err(e) => bind_err = Some(e),
}
}
}
});
if let Some(e) = bind_err {
return Err(e);
}
Ok::<_, Error>(())
}
.await
{
tracing::error!("error updating webserver bind: {e}");
tracing::debug!("{e:?}");
tokio::time::sleep(Duration::from_secs(5)).await;
}
}
})
.into();
server.serve_tunnel(ctx.clone());
let mut shutdown_recv = ctx.shutdown.subscribe();
let sig_handler_ctx = ctx;
let sig_handler = tokio::spawn(async move {
let sig_handler: NonDetachingJoinHandle<()> = tokio::spawn(async move {
use tokio::signal::unix::SignalKind;
futures::future::select_all(
[
@@ -48,14 +88,16 @@ async fn inner_main(config: &TunnelConfig) -> Result<(), Error> {
.send(())
.map_err(|_| ())
.expect("send shutdown signal");
});
})
.into();
shutdown_recv
.recv()
.await
.with_kind(crate::ErrorKind::Unknown)?;
sig_handler.abort();
sig_handler.wait_for_abort().await;
setter_thread.wait_for_abort().await;
Ok::<_, Error>(server)
}

View File

@@ -1,6 +1,7 @@
pub mod model;
pub mod prelude;
use std::panic::UnwindSafe;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
@@ -12,7 +13,7 @@ use itertools::Itertools;
use patch_db::json_ptr::{JsonPointer, ROOT};
use patch_db::{DiffPatch, Dump, Revision};
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use rpc_toolkit::{from_fn_async, Context, HandlerArgs, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::{self, UnboundedReceiver};
use tokio::sync::watch;
@@ -23,12 +24,22 @@ use crate::context::{CliContext, RpcContext};
use crate::prelude::*;
use crate::rpc_continuations::{Guid, RpcContinuation};
use crate::util::net::WebSocketExt;
use crate::util::serde::{HandlerExtSerde, apply_expr};
use crate::util::serde::{apply_expr, HandlerExtSerde};
lazy_static::lazy_static! {
static ref PUBLIC: JsonPointer = "/public".parse().unwrap();
}
pub trait DbAccess<T>: Sized {
type Key<'a>;
fn access<'a>(db: &'a Model<Self>, key: Self::Key<'_>) -> &'a Model<T>;
}
pub trait DbAccessMut<T>: Sized {
type Key<'a>;
fn access_mut<'a>(db: &'a mut Model<Self>, key: Self::Key<'_>) -> &'a mut Model<T>;
}
pub fn db<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
.subcommand(

View File

@@ -1,5 +1,6 @@
use std::collections::{BTreeMap, BTreeSet, VecDeque};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use chrono::{DateTime, Utc};
use exver::{Version, VersionRange};
@@ -8,7 +9,6 @@ use imbl_value::InternedString;
use ipnet::IpNet;
use isocountry::CountryCode;
use itertools::Itertools;
use lazy_static::lazy_static;
use models::{GatewayId, PackageId};
use openssl::hash::MessageDigest;
use patch_db::{HasModel, Value};
@@ -18,7 +18,6 @@ use ts_rs::TS;
use crate::account::AccountInfo;
use crate::db::model::package::AllPackageData;
use crate::net::acme::AcmeProvider;
use crate::net::forward::START9_BRIDGE_IFACE;
use crate::net::host::binding::{AddSslOptions, BindInfo, BindOptions, NetInfo};
use crate::net::host::Host;
use crate::net::utils::ipv6_is_local;
@@ -30,7 +29,7 @@ use crate::util::cpupower::Governor;
use crate::util::lshw::LshwDevice;
use crate::util::serde::MaybeUtf8String;
use crate::version::{Current, VersionT};
use crate::{ARCH, HOST_IP, PLATFORM};
use crate::{ARCH, PLATFORM};
#[derive(Debug, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
@@ -216,40 +215,9 @@ pub struct NetworkInterfaceInfo {
pub name: Option<InternedString>,
pub public: Option<bool>,
pub secure: Option<bool>,
pub ip_info: Option<IpInfo>,
pub ip_info: Option<Arc<IpInfo>>,
}
impl NetworkInterfaceInfo {
pub fn loopback() -> (&'static GatewayId, &'static Self) {
lazy_static! {
static ref LO: GatewayId = GatewayId::from(InternedString::intern("lo"));
static ref LOOPBACK: NetworkInterfaceInfo = NetworkInterfaceInfo {
name: Some(InternedString::from_static("Loopback")),
public: Some(false),
secure: Some(true),
ip_info: Some(IpInfo {
name: "lo".into(),
scope_id: 1,
device_type: None,
subnets: [
IpNet::new(Ipv4Addr::LOCALHOST.into(), 8).unwrap(),
IpNet::new(Ipv6Addr::LOCALHOST.into(), 128).unwrap(),
]
.into_iter()
.collect(),
lan_ip: [
IpAddr::from(Ipv4Addr::LOCALHOST),
IpAddr::from(Ipv6Addr::LOCALHOST)
]
.into_iter()
.collect(),
wan_ip: None,
ntp_servers: Default::default(),
dns_servers: Default::default(),
}),
};
}
(&*LO, &*LOOPBACK)
}
pub fn public(&self) -> bool {
self.public.unwrap_or_else(|| {
!self.ip_info.as_ref().map_or(true, |ip_info| {
@@ -309,7 +277,7 @@ pub struct IpInfo {
pub dns_servers: OrdSet<IpAddr>,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize, TS)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, TS)]
#[ts(export)]
#[serde(rename_all = "kebab-case")]
pub enum NetworkInterfaceType {

View File

@@ -193,7 +193,7 @@ where
A: serde::Serialize + serde::de::DeserializeOwned + Ord,
B: serde::Serialize + serde::de::DeserializeOwned,
{
type Key = A;
type Key = JsonKey<A>;
type Value = B;
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
serde_json::to_string(key).with_kind(ErrorKind::Serialization)
@@ -433,6 +433,12 @@ impl<T> std::ops::DerefMut for JsonKey<T> {
&mut self.0
}
}
impl<T: DeserializeOwned> FromStr for JsonKey<T> {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
serde_json::from_str(s).with_kind(ErrorKind::Deserialization)
}
}
impl<T: Serialize> Serialize for JsonKey<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
@@ -445,7 +451,7 @@ impl<T: Serialize> Serialize for JsonKey<T> {
}
}
// { "foo": "bar" } -> "{ \"foo\": \"bar\" }"
impl<'de, T: Serialize + DeserializeOwned> Deserialize<'de> for JsonKey<T> {
impl<'de, T: DeserializeOwned> Deserialize<'de> for JsonKey<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,

View File

@@ -22,10 +22,11 @@ use crate::db::model::Database;
use crate::developer::OS_DEVELOPER_KEY_PATH;
use crate::hostname::Hostname;
use crate::middleware::auth::AuthContext;
use crate::net::gateway::UpgradableListener;
use crate::net::net_controller::{NetController, NetService};
use crate::net::socks::DEFAULT_SOCKS_LISTEN;
use crate::net::utils::find_wifi_iface;
use crate::net::web_server::{UpgradableListener, WebServerAcceptorSetter};
use crate::net::web_server::WebServerAcceptorSetter;
use crate::prelude::*;
use crate::progress::{
FullProgress, FullProgressTracker, PhaseProgressTrackerHandle, PhasedProgressBar, ProgressUnits,

View File

@@ -1,24 +1,229 @@
use std::collections::{BTreeMap, BTreeSet};
use std::net::IpAddr;
use std::str::FromStr;
use std::sync::Arc;
use async_acme::acme::Identifier;
use clap::Parser;
use async_acme::acme::{Identifier, ACME_TLS_ALPN_NAME};
use clap::builder::ValueParserFactory;
use clap::Parser;
use futures::StreamExt;
use imbl_value::InternedString;
use itertools::Itertools;
use models::{ErrorData, FromStrParser};
use openssl::pkey::{PKey, Private};
use openssl::x509::X509;
use rpc_toolkit::{Context, HandlerExt, ParentHandler, from_fn_async};
use rpc_toolkit::{from_fn_async, Context, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use tokio_rustls::rustls::crypto::CryptoProvider;
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio_rustls::rustls::sign::CertifiedKey;
use tokio_rustls::rustls::ServerConfig;
use ts_rs::TS;
use url::Url;
use crate::context::{CliContext, RpcContext};
use crate::db::model::Database;
use crate::db::model::public::AcmeSettings;
use crate::db::model::Database;
use crate::db::{DbAccess, DbAccessMut};
use crate::net::tls::{SingleCertResolver, TlsHandler};
use crate::net::web_server::Accept;
use crate::prelude::*;
use crate::util::serde::{Pem, Pkcs8Doc};
use crate::util::sync::{SyncMutex, Watch};
pub type AcmeTlsAlpnCache =
Arc<SyncMutex<BTreeMap<InternedString, Watch<Option<Arc<CertifiedKey>>>>>>;
pub struct AcmeTlsHandler<'a, M: HasModel, S: 'a> {
pub db: &'a TypedPatchDb<M>,
pub acme_cache: &'a AcmeTlsAlpnCache,
pub crypto_provider: &'a Arc<CryptoProvider>,
pub get_provider: S,
pub in_progress: Watch<BTreeSet<BTreeSet<InternedString>>>,
}
impl<'b, M, S> AcmeTlsHandler<'b, M, S>
where
for<'a> M: DbAccess<AcmeCertStore, Key<'a> = ()>
+ DbAccess<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ DbAccessMut<AcmeCertStore, Key<'a> = ()>
+ HasModel<Model = Model<M>>
+ Send
+ Sync,
S: GetAcmeProvider<'b> + Clone + 'b,
{
pub async fn get_cert(&self, san_info: &BTreeSet<InternedString>) -> Option<CertifiedKey> {
let provider = self.get_provider.clone().get_provider(san_info).await?;
loop {
let peek = self.db.peek().await;
let store = <M as DbAccess<AcmeCertStore>>::access(&peek, ());
if let Some(cert) = store
.as_certs()
.as_idx(&provider.0)
.and_then(|p| p.as_idx(JsonKey::new_ref(san_info)))
{
let cert = cert.de().log_err()?;
return Some(
CertifiedKey::from_der(
cert.fullchain
.into_iter()
.map(|c| Ok(CertificateDer::from(c.to_der()?)))
.collect::<Result<_, Error>>()
.log_err()?,
PrivateKeyDer::from(PrivatePkcs8KeyDer::from(
cert.key.0.private_key_to_pkcs8().log_err()?,
)),
&*self.crypto_provider,
)
.log_err()?,
);
}
if !self.in_progress.send_if_modified(|x| {
if !x.contains(san_info) {
x.insert(san_info.clone());
true
} else {
false
}
}) {
self.in_progress
.clone()
.wait_for(|x| !x.contains(san_info))
.await;
continue;
}
let contact = <M as DbAccess<AcmeSettings>>::access(&peek, provider)
.as_contact()
.de()
.log_err()?;
let identifiers: Vec<_> = san_info
.iter()
.map(|d| match d.parse::<IpAddr>() {
Ok(a) => Identifier::Ip(a),
_ => Identifier::Dns((&**d).into()),
})
.collect::<Vec<_>>();
let cache_entries = san_info
.iter()
.cloned()
.map(|d| (d, Watch::new(None)))
.collect::<BTreeMap<_, _>>();
self.acme_cache.mutate(|c| {
c.extend(cache_entries.iter().map(|(k, v)| (k.clone(), v.clone())));
});
let cert = async_acme::rustls_helper::order(
|identifier, cert| {
let domain = InternedString::from_display(&identifier);
if let Some(entry) = cache_entries.get(&domain) {
entry.send(Some(Arc::new(cert)));
}
Ok(())
},
provider.0.as_str(),
&identifiers,
Some(&AcmeCertCache(&self.db)),
&contact,
)
.await
.log_err()?;
self.acme_cache
.mutate(|c| c.retain(|c, _| !cache_entries.contains_key(c)));
self.in_progress.send_modify(|i| i.remove(san_info));
return Some(cert);
}
}
}
pub trait GetAcmeProvider<'a> {
fn get_provider<'b>(
self,
san_info: &'b BTreeSet<InternedString>,
) -> impl Future<Output = Option<&'a AcmeProvider>> + Send + 'b
where
Self: 'b;
}
impl<'b, A, M, S> TlsHandler<A> for &'b AcmeTlsHandler<'b, M, S>
where
A: Accept,
<A as Accept>::Metadata: Send + Sync,
for<'a> M: DbAccess<AcmeCertStore, Key<'a> = ()>
+ DbAccess<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ DbAccessMut<AcmeCertStore, Key<'a> = ()>
+ HasModel<Model = Model<M>>
+ Send
+ Sync,
S: GetAcmeProvider<'b> + Clone + Send + Sync + 'b,
{
async fn get_config<'a>(
self,
hello: &'a tokio_rustls::rustls::server::ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig>
where
Self: 'a,
A: 'a,
<A as Accept>::Metadata: 'a,
{
let domain = hello.server_name()?;
if hello
.alpn()
.into_iter()
.flatten()
.any(|a| a == ACME_TLS_ALPN_NAME)
{
let cert = self
.acme_cache
.peek(|c| c.get(domain).cloned())
.ok_or_else(|| {
Error::new(
eyre!("No challenge recv available for {domain}"),
ErrorKind::OpenSsl,
)
})
.log_err()?;
tracing::info!("Waiting for verification cert for {domain}");
let cert = cert
.filter(|c| futures::future::ready(c.is_some()))
.next()
.await
.flatten()?;
tracing::info!("Verification cert received for {domain}");
let mut cfg = ServerConfig::builder_with_provider(self.crypto_provider.clone())
.with_safe_default_protocol_versions()
.log_err()?
.with_no_client_auth()
.with_cert_resolver(Arc::new(SingleCertResolver(cert)));
cfg.alpn_protocols = vec![ACME_TLS_ALPN_NAME.to_vec()];
tracing::info!("performing ACME auth challenge");
return Some(cfg);
}
let domains: BTreeSet<InternedString> = [domain.into()].into_iter().collect();
let crypto_provider = self.crypto_provider.clone();
if let Some(cert) = self.get_cert(&domains).await {
return Some(
ServerConfig::builder_with_provider(crypto_provider)
.with_safe_default_protocol_versions()
.log_err()?
.with_no_client_auth()
.with_cert_resolver(Arc::new(SingleCertResolver(Arc::new(cert)))),
);
}
None
}
}
#[derive(Debug, Default, Deserialize, Serialize, HasModel)]
#[model = "Model<Self>"]
@@ -32,29 +237,41 @@ impl AcmeCertStore {
}
}
impl DbAccess<AcmeCertStore> for Database {
type Key<'a> = ();
fn access<'a>(db: &'a Model<Self>, _: Self::Key<'_>) -> &'a Model<AcmeCertStore> {
db.as_private().as_key_store().as_acme()
}
}
impl DbAccessMut<AcmeCertStore> for Database {
type Key<'a> = ();
fn access_mut<'a>(db: &'a mut Model<Self>, _: Self::Key<'_>) -> &'a mut Model<AcmeCertStore> {
db.as_private_mut().as_key_store_mut().as_acme_mut()
}
}
#[derive(Debug, Deserialize, Serialize)]
pub struct AcmeCert {
pub key: Pem<PKey<Private>>,
pub fullchain: Vec<Pem<X509>>,
}
pub struct AcmeCertCache<'a>(pub &'a TypedPatchDb<Database>);
pub struct AcmeCertCache<'a, M: HasModel>(pub &'a TypedPatchDb<M>);
#[async_trait::async_trait]
impl<'a> async_acme::cache::AcmeCache for AcmeCertCache<'a> {
impl<'a, M> async_acme::cache::AcmeCache for AcmeCertCache<'a, M>
where
for<'b> M: HasModel<Model = Model<M>>
+ DbAccess<AcmeCertStore, Key<'b> = ()>
+ DbAccessMut<AcmeCertStore, Key<'b> = ()>
+ Send
+ Sync,
{
type Error = ErrorData;
async fn read_account(&self, contacts: &[&str]) -> Result<Option<Vec<u8>>, Self::Error> {
let contacts = JsonKey::new(contacts.into_iter().map(|s| (*s).to_owned()).collect_vec());
let Some(account) = self
.0
.peek()
.await
.into_private()
.into_key_store()
.into_acme()
.into_accounts()
.into_idx(&contacts)
else {
let peek = self.0.peek().await;
let Some(account) = M::access(&peek, ()).as_accounts().as_idx(&contacts) else {
return Ok(None);
};
Ok(Some(account.de()?.0.document.into_vec()))
@@ -68,9 +285,7 @@ impl<'a> async_acme::cache::AcmeCache for AcmeCertCache<'a> {
};
self.0
.mutate(|db| {
db.as_private_mut()
.as_key_store_mut()
.as_acme_mut()
M::access_mut(db, ())
.as_accounts_mut()
.insert(&contacts, &Pem::new(key))
})
@@ -96,16 +311,11 @@ impl<'a> async_acme::cache::AcmeCache for AcmeCertCache<'a> {
let directory_url = directory_url
.parse::<Url>()
.with_kind(ErrorKind::ParseUrl)?;
let Some(cert) = self
.0
.peek()
.await
.into_private()
.into_key_store()
.into_acme()
.into_certs()
.into_idx(&directory_url)
.and_then(|a| a.into_idx(&identifiers))
let peek = self.0.peek().await;
let Some(cert) = M::access(&peek, ())
.as_certs()
.as_idx(&directory_url)
.and_then(|a| a.as_idx(&identifiers))
else {
return Ok(None);
};
@@ -160,9 +370,7 @@ impl<'a> async_acme::cache::AcmeCache for AcmeCertCache<'a> {
};
self.0
.mutate(|db| {
db.as_private_mut()
.as_key_store_mut()
.as_acme_mut()
M::access_mut(db, ())
.as_certs_mut()
.upsert(&directory_url, || Ok(BTreeMap::new()))?
.insert(&identifiers, &cert)

View File

@@ -415,7 +415,6 @@ impl Resolver {
{
if let Some(res) = self.net_iface.peek(|i| {
i.values()
.chain([NetworkInterfaceInfo::loopback().1])
.filter_map(|i| i.ip_info.as_ref())
.find(|i| i.subnets.iter().any(|s| s.contains(&src)))
.map(|ip_info| {

View File

@@ -3,10 +3,11 @@ use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::future::Future;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV6};
use std::sync::{Arc, Weak};
use std::task::Poll;
use std::task::{ready, Poll};
use std::time::Duration;
use clap::Parser;
use futures::future::Either;
use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
use helpers::NonDetachingJoinHandle;
use imbl::{OrdMap, OrdSet};
@@ -19,10 +20,11 @@ use patch_db::json_ptr::JsonPointer;
use rpc_toolkit::{from_fn_async, Context, HandlerArgs, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::net::{TcpListener, TcpStream};
use tokio::net::TcpListener;
use tokio::process::Command;
use tokio::sync::oneshot;
use ts_rs::TS;
use visit_rs::{Visit, VisitFields};
use zbus::proxy::{PropertyChanged, PropertyStream, SignalStream};
use zbus::zvariant::{
DeserializeDict, Dict, OwnedObjectPath, OwnedValue, Type as ZType, Value as ZValue,
@@ -35,7 +37,7 @@ use crate::db::model::Database;
use crate::net::forward::START9_BRIDGE_IFACE;
use crate::net::gateway::device::DeviceProxy;
use crate::net::utils::ipv6_is_link_local;
use crate::net::web_server::Accept;
use crate::net::web_server::{Accept, AcceptStream, Acceptor, MetadataVisitor};
use crate::prelude::*;
use crate::util::collections::OrdMapIterMut;
use crate::util::future::Until;
@@ -714,6 +716,7 @@ async fn watch_ip(
)
});
ip_info.wan_ip = ip_info.wan_ip.or(prev_wan_ip);
let ip_info = Arc::new(ip_info);
m.insert(
iface.clone(),
NetworkInterfaceInfo {
@@ -828,7 +831,7 @@ impl NetworkInterfaceWatcher {
self.ip_info.read()
}
pub fn bind(&self, port: u16) -> Result<NetworkInterfaceListener, Error> {
pub fn bind<B: Bind>(&self, bind: B, port: u16) -> Result<NetworkInterfaceListener<B>, Error> {
let arc = Arc::new(());
self.listeners.mutate(|l| {
if l.get(&port).filter(|w| w.strong_count() > 0).is_some() {
@@ -841,22 +844,20 @@ impl NetworkInterfaceWatcher {
Ok(())
})?;
let ip_info = self.ip_info.clone_unseen();
let activated = self.activated.clone_unseen();
Ok(NetworkInterfaceListener {
_arc: arc,
ip_info,
activated,
listeners: ListenerMap::new(port),
listeners: ListenerMap::new(bind, port),
})
}
pub fn upgrade_listener(
pub fn upgrade_listener<B: Bind>(
&self,
SelfContainedNetworkInterfaceListener {
mut listener,
..
}: SelfContainedNetworkInterfaceListener,
) -> Result<NetworkInterfaceListener, Error> {
}: SelfContainedNetworkInterfaceListener<B>,
) -> Result<NetworkInterfaceListener<B>, Error> {
let port = listener.listeners.port;
let arc = &listener._arc;
self.listeners.mutate(|l| {
@@ -1169,45 +1170,6 @@ impl NetworkInterfaceController {
}
}
struct ListenerMap {
prev_filter: DynInterfaceFilter,
port: u16,
listeners: BTreeMap<SocketAddr, (TcpListener, Option<Ipv4Addr>)>,
}
impl ListenerMap {
fn from_listener(listener: impl IntoIterator<Item = TcpListener>) -> Result<Self, Error> {
let mut port = 0;
let mut listeners = BTreeMap::<SocketAddr, (TcpListener, Option<Ipv4Addr>)>::new();
for listener in listener {
let mut local = listener.local_addr().with_kind(ErrorKind::Network)?;
if let SocketAddr::V6(l) = &mut local {
if ipv6_is_link_local(*l.ip()) && l.scope_id() == 0 {
continue; // TODO determine scope id
}
}
if port != 0 && port != local.port() {
return Err(Error::new(
eyre!("Provided listeners are bound to different ports"),
ErrorKind::InvalidRequest,
));
}
port = local.port();
listeners.insert(local, (listener, None));
}
if port == 0 {
return Err(Error::new(
eyre!("Listener array cannot be empty"),
ErrorKind::InvalidRequest,
));
}
Ok(Self {
prev_filter: false.into_dyn(),
port,
listeners,
})
}
}
pub trait InterfaceFilter: Any + Clone + std::fmt::Debug + Eq + Ord + Send + Sync {
fn filter(&self, id: &GatewayId, info: &NetworkInterfaceInfo) -> bool;
fn eq(&self, other: &dyn Any) -> bool {
@@ -1235,6 +1197,14 @@ impl InterfaceFilter for bool {
}
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct TypeFilter(pub NetworkInterfaceType);
impl InterfaceFilter for TypeFilter {
fn filter(&self, _: &GatewayId, info: &NetworkInterfaceInfo) -> bool {
info.ip_info.as_ref().and_then(|i| i.device_type) == Some(self.0)
}
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IdFilter(pub GatewayId);
impl InterfaceFilter for IdFilter {
@@ -1366,10 +1336,17 @@ impl Ord for DynInterfaceFilter {
}
}
impl ListenerMap {
fn new(port: u16) -> Self {
struct ListenerMap<B: Bind> {
prev_filter: DynInterfaceFilter,
bind: B,
port: u16,
listeners: BTreeMap<SocketAddr, B::Accept>,
}
impl<B: Bind> ListenerMap<B> {
fn new(bind: B, port: u16) -> Self {
Self {
prev_filter: false.into_dyn(),
bind,
port,
listeners: BTreeMap::new(),
}
@@ -1384,7 +1361,6 @@ impl ListenerMap {
let mut keep = BTreeSet::<SocketAddr>::new();
for (_, info) in ip_info
.iter()
.chain([NetworkInterfaceInfo::loopback()])
.filter(|(id, info)| filter.filter(*id, *info))
{
if let Some(ip_info) = &info.ip_info {
@@ -1404,24 +1380,9 @@ impl ListenerMap {
ip => SocketAddr::new(ip, self.port),
};
keep.insert(addr);
if let Some((_, wan_ip)) = self.listeners.get_mut(&addr) {
*wan_ip = info.ip_info.as_ref().and_then(|i| i.wan_ip);
continue;
if !self.listeners.contains_key(&addr) {
self.listeners.insert(addr, self.bind.bind(addr)?);
}
self.listeners.insert(
addr,
(
TcpListener::from_std(
mio::net::TcpListener::bind(addr)
.with_ctx(|_| {
(ErrorKind::Network, lazy_format!("binding to {addr:?}"))
})?
.into(),
)
.with_kind(ErrorKind::Network)?,
info.ip_info.as_ref().and_then(|i| i.wan_ip),
),
);
}
}
}
@@ -1429,24 +1390,13 @@ impl ListenerMap {
self.prev_filter = filter.clone().into_dyn();
Ok(())
}
fn poll_accept(&self, cx: &mut std::task::Context<'_>) -> Poll<Result<Accepted, Error>> {
for (bind_addr, (listener, wan_ip)) in self.listeners.iter() {
if let Poll::Ready((stream, addr)) = listener.poll_accept(cx)? {
if let Err(e) = socket2::SockRef::from(&stream).set_tcp_keepalive(
&socket2::TcpKeepalive::new()
.with_time(Duration::from_secs(900))
.with_interval(Duration::from_secs(60))
.with_retries(5),
) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
return Poll::Ready(Ok(Accepted {
stream,
peer: addr,
wan_ip: *wan_ip,
bind: *bind_addr,
}));
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(SocketAddr, <B::Accept as Accept>::Metadata, AcceptStream), Error>> {
for (addr, listener) in self.listeners.iter_mut() {
if let Poll::Ready((metadata, stream)) = listener.poll_accept(cx)? {
return Poll::Ready(Ok((*addr, metadata, stream)));
}
}
Poll::Pending
@@ -1457,54 +1407,100 @@ pub fn lookup_info_by_addr(
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
addr: SocketAddr,
) -> Option<(&GatewayId, &NetworkInterfaceInfo)> {
ip_info
.iter()
.chain([NetworkInterfaceInfo::loopback()])
.find(|(_, i)| {
i.ip_info
.as_ref()
.map_or(false, |i| i.subnets.iter().any(|i| i.addr() == addr.ip()))
})
ip_info.iter().find(|(_, i)| {
i.ip_info
.as_ref()
.map_or(false, |i| i.subnets.iter().any(|i| i.addr() == addr.ip()))
})
}
pub struct NetworkInterfaceListener {
pub trait Bind {
type Accept: Accept;
fn bind(&mut self, addr: SocketAddr) -> Result<Self::Accept, Error>;
}
#[derive(Clone, Copy, Default)]
pub struct BindTcp;
impl Bind for BindTcp {
type Accept = TcpListener;
fn bind(&mut self, addr: SocketAddr) -> Result<Self::Accept, Error> {
TcpListener::from_std(
mio::net::TcpListener::bind(addr)
.with_kind(ErrorKind::Network)?
.into(),
)
.with_kind(ErrorKind::Network)
}
}
pub trait FromGatewayInfo {
fn from_gateway_info(id: &GatewayId, info: &NetworkInterfaceInfo) -> Self;
}
#[derive(Clone, Debug)]
pub struct GatewayInfo {
pub id: GatewayId,
pub info: NetworkInterfaceInfo,
}
impl<V: MetadataVisitor> Visit<V> for GatewayInfo {
fn visit(&self, visitor: &mut V) -> <V as visit_rs::Visitor>::Result {
visitor.visit(self)
}
}
impl FromGatewayInfo for GatewayInfo {
fn from_gateway_info(id: &GatewayId, info: &NetworkInterfaceInfo) -> Self {
Self {
id: id.clone(),
info: info.clone(),
}
}
}
pub struct NetworkInterfaceListener<B: Bind = BindTcp> {
pub ip_info: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
activated: Watch<BTreeMap<GatewayId, bool>>,
listeners: ListenerMap,
listeners: ListenerMap<B>,
_arc: Arc<()>,
}
impl NetworkInterfaceListener {
impl<B: Bind> NetworkInterfaceListener<B> {
pub(super) fn new(
mut ip_info: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
bind: B,
port: u16,
) -> Self {
ip_info.mark_unseen();
Self {
ip_info,
listeners: ListenerMap::new(bind, port),
_arc: Arc::new(()),
}
}
pub fn port(&self) -> u16 {
self.listeners.port
}
#[cfg_attr(feature = "unstable", inline(never))]
pub fn poll_accept(
pub fn poll_accept<M: FromGatewayInfo>(
&mut self,
cx: &mut std::task::Context<'_>,
filter: &impl InterfaceFilter,
) -> Poll<Result<Accepted, Error>> {
) -> Poll<Result<(M, <B::Accept as Accept>::Metadata, AcceptStream), Error>> {
while self.ip_info.poll_changed(cx).is_ready()
|| !DynInterfaceFilterT::eq(&self.listeners.prev_filter, filter.as_any())
{
self.ip_info
.peek_and_mark_seen(|ip_info| self.listeners.update(ip_info, filter))?;
}
self.listeners.poll_accept(cx)
}
pub(super) fn new(
mut ip_info: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
activated: Watch<BTreeMap<GatewayId, bool>>,
port: u16,
) -> Self {
ip_info.mark_unseen();
Self {
ip_info,
activated,
listeners: ListenerMap::new(port),
_arc: Arc::new(()),
}
let (addr, inner, stream) = ready!(self.listeners.poll_accept(cx)?);
Poll::Ready(Ok((
self.ip_info
.peek(|ip_info| {
lookup_info_by_addr(ip_info, addr)
.map(|(id, info)| M::from_gateway_info(id, info))
})
.or_not_found(lazy_format!("gateway for {addr}"))?,
inner,
stream,
)))
}
pub fn change_ip_info_source(
@@ -1515,7 +1511,10 @@ impl NetworkInterfaceListener {
self.ip_info = ip_info;
}
pub async fn accept(&mut self, filter: &impl InterfaceFilter) -> Result<Accepted, Error> {
pub async fn accept<M: FromGatewayInfo>(
&mut self,
filter: &impl InterfaceFilter,
) -> Result<(M, <B::Accept as Accept>::Metadata, AcceptStream), Error> {
futures::future::poll_fn(|cx| self.poll_accept(cx, filter)).await
}
@@ -1531,40 +1530,73 @@ impl NetworkInterfaceListener {
}
}
pub struct Accepted {
pub stream: TcpStream,
pub peer: SocketAddr,
pub wan_ip: Option<Ipv4Addr>,
pub bind: SocketAddr,
#[derive(VisitFields)]
pub struct NetworkInterfaceListenerAcceptMetadata<B: Bind> {
pub inner: <B::Accept as Accept>::Metadata,
pub info: GatewayInfo,
}
pub struct SelfContainedNetworkInterfaceListener {
_watch_thread: NonDetachingJoinHandle<()>,
listener: NetworkInterfaceListener,
}
impl SelfContainedNetworkInterfaceListener {
pub fn bind(port: u16) -> Self {
let ip_info = Watch::new(OrdMap::new());
let activated = Watch::new(
[(
GatewayId::from(InternedString::from(START9_BRIDGE_IFACE)),
false,
)]
.into_iter()
.collect(),
);
let _watch_thread = tokio::spawn(watcher(ip_info.clone(), activated.clone())).into();
Self {
_watch_thread,
listener: NetworkInterfaceListener::new(ip_info, activated, port),
}
impl<B, V> Visit<V> for NetworkInterfaceListenerAcceptMetadata<B>
where
B: Bind,
<B::Accept as Accept>::Metadata: Visit<V> + Clone + Send + Sync + 'static,
V: MetadataVisitor,
{
fn visit(&self, visitor: &mut V) -> V::Result {
self.visit_fields(visitor).collect()
}
}
impl Accept for SelfContainedNetworkInterfaceListener {
impl<B: Bind> Accept for NetworkInterfaceListener<B> {
type Metadata = NetworkInterfaceListenerAcceptMetadata<B>;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<super::web_server::Accepted, Error>> {
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
NetworkInterfaceListener::poll_accept(self, cx, &true).map(|res| {
res.map(|(info, inner, stream)| {
(
NetworkInterfaceListenerAcceptMetadata { inner, info },
stream,
)
})
})
}
}
pub struct SelfContainedNetworkInterfaceListener<B: Bind = BindTcp> {
_watch_thread: NonDetachingJoinHandle<()>,
listener: NetworkInterfaceListener<B>,
}
impl<B: Bind> SelfContainedNetworkInterfaceListener<B> {
pub fn bind(bind: B, port: u16) -> Self {
let ip_info = Watch::new(OrdMap::new());
let _watch_thread =
tokio::spawn(watcher(ip_info.clone(), Watch::new(BTreeMap::new()))).into();
Self {
_watch_thread,
listener: NetworkInterfaceListener::new(ip_info, bind, port),
}
}
}
impl<B: Bind> Accept for SelfContainedNetworkInterfaceListener<B> {
type Metadata = <NetworkInterfaceListener<B> as Accept>::Metadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(Self::Metadata, AcceptStream), Error>> {
Accept::poll_accept(&mut self.listener, cx)
}
}
pub type UpgradableListener<B = BindTcp> =
Option<Either<SelfContainedNetworkInterfaceListener<B>, NetworkInterfaceListener<B>>>;
impl<B> Acceptor<UpgradableListener<B>>
where
B: Bind + Send + Sync + 'static,
B::Accept: Send + Sync,
{
pub fn bind_upgradable(listener: SelfContainedNetworkInterfaceListener<B>) -> Self {
Self::new(Some(Either::Left(listener)))
}
}

View File

@@ -12,6 +12,7 @@ pub mod service_interface;
pub mod socks;
pub mod ssl;
pub mod static_server;
pub mod tls;
pub mod tor;
pub mod tunnel;
pub mod utils;

View File

@@ -19,7 +19,7 @@ use crate::net::dns::DnsController;
use crate::net::forward::{PortForwardController, START9_BRIDGE_IFACE};
use crate::net::gateway::{
AndFilter, DynInterfaceFilter, IdFilter, InterfaceFilter, NetworkInterfaceController, OrFilter,
PublicFilter, SecureFilter,
PublicFilter, SecureFilter, TypeFilter,
};
use crate::net::host::address::HostAddress;
use crate::net::host::binding::{AddSslOptions, BindId, BindOptions};
@@ -28,7 +28,7 @@ use crate::net::service_interface::{GatewayInfo, HostnameInfo, IpHostname, Onion
use crate::net::socks::SocksController;
use crate::net::tor::{OnionAddress, TorController, TorSecretKey};
use crate::net::utils::ipv6_is_local;
use crate::net::vhost::{AlpnInfo, TargetInfo, VHostController};
use crate::net::vhost::{AlpnInfo, ProxyTarget, VHostController};
use crate::prelude::*;
use crate::service::effects::callbacks::ServiceCallbacks;
use crate::util::serde::MaybeUtf8String;
@@ -134,7 +134,7 @@ impl NetController {
#[derive(Default, Debug)]
struct HostBinds {
forwards: BTreeMap<u16, (SocketAddr, DynInterfaceFilter, Arc<()>)>,
vhosts: BTreeMap<(Option<InternedString>, u16), (TargetInfo, Arc<()>)>,
vhosts: BTreeMap<(Option<InternedString>, u16), (ProxyTarget, Arc<()>)>,
private_dns: BTreeMap<InternedString, Arc<()>>,
tor: BTreeMap<OnionAddress, (OrdMap<u16, SocketAddr>, Vec<Arc<()>>)>,
}
@@ -226,7 +226,7 @@ impl NetServiceData {
async fn update(&mut self, ctrl: &NetController, id: HostId, host: Host) -> Result<(), Error> {
let mut forwards: BTreeMap<u16, (SocketAddr, DynInterfaceFilter)> = BTreeMap::new();
let mut vhosts: BTreeMap<(Option<InternedString>, u16), TargetInfo> = BTreeMap::new();
let mut vhosts: BTreeMap<(Option<InternedString>, u16), ProxyTarget> = BTreeMap::new();
let mut private_dns: BTreeSet<InternedString> = BTreeSet::new();
let mut tor: BTreeMap<OnionAddress, (TorSecretKey, OrdMap<u16, SocketAddr>)> =
BTreeMap::new();
@@ -263,7 +263,7 @@ impl NetServiceData {
for hostname in ctrl.server_hostnames.iter().cloned() {
vhosts.insert(
(hostname, external),
TargetInfo {
ProxyTarget {
filter: bind.net.clone().into_dyn(),
acme: None,
addr,
@@ -278,11 +278,9 @@ impl NetServiceData {
if hostnames.insert(hostname.clone()) {
vhosts.insert(
(Some(hostname), external),
TargetInfo {
ProxyTarget {
filter: OrFilter(
IdFilter(
NetworkInterfaceInfo::loopback().0.clone(),
),
TypeFilter(NetworkInterfaceType::Loopback),
IdFilter(GatewayId::from(InternedString::from(
START9_BRIDGE_IFACE,
))),
@@ -306,7 +304,7 @@ impl NetServiceData {
if let Some(public) = &public {
vhosts.insert(
(address.clone(), 5443),
TargetInfo {
ProxyTarget {
filter: AndFilter(
bind.net.clone(),
AndFilter(
@@ -322,7 +320,7 @@ impl NetServiceData {
);
vhosts.insert(
(address.clone(), 443),
TargetInfo {
ProxyTarget {
filter: AndFilter(
bind.net.clone(),
if private {
@@ -348,7 +346,7 @@ impl NetServiceData {
} else {
vhosts.insert(
(address.clone(), 443),
TargetInfo {
ProxyTarget {
filter: AndFilter(
bind.net.clone(),
PublicFilter { public: false },
@@ -364,7 +362,7 @@ impl NetServiceData {
if let Some(public) = public {
vhosts.insert(
(address.clone(), external),
TargetInfo {
ProxyTarget {
filter: AndFilter(
bind.net.clone(),
if private {
@@ -387,7 +385,7 @@ impl NetServiceData {
} else {
vhosts.insert(
(address.clone(), external),
TargetInfo {
ProxyTarget {
filter: AndFilter(
bind.net.clone(),
PublicFilter { public: false },

View File

@@ -8,7 +8,7 @@ use std::time::UNIX_EPOCH;
use async_compression::tokio::bufread::GzipEncoder;
use axum::body::Body;
use axum::extract::{self as x, Request};
use axum::response::{Redirect, Response};
use axum::response::{IntoResponse, Redirect, Response};
use axum::routing::{any, get};
use axum::Router;
use base64::display::Base64Display;
@@ -37,6 +37,8 @@ use crate::main_api;
use crate::middleware::auth::{Auth, HasValidSession};
use crate::middleware::cors::Cors;
use crate::middleware::db::SyncDb;
use crate::net::gateway::GatewayInfo;
use crate::net::tls::TlsHandshakeInfo;
use crate::prelude::*;
use crate::rpc_continuations::{Guid, RpcContinuations};
use crate::s9pk::merkle_archive::source::http::HttpSource;
@@ -80,6 +82,30 @@ impl UiContext for RpcContext {
.middleware(SyncDb::new())
}
fn extend_router(self, router: Router) -> Router {
async fn https_redirect_if_public_http(
req: Request,
next: axum::middleware::Next,
) -> Response {
if req
.extensions()
.get::<GatewayInfo>()
.map_or(false, |p| p.info.public())
&& req.extensions().get::<TlsHandshakeInfo>().is_none()
{
Redirect::temporary(&format!(
"https://{}{}",
req.headers()
.get(HOST)
.and_then(|s| s.to_str().ok())
.unwrap_or("localhost"),
req.uri()
))
.into_response()
} else {
next.run(req).await
}
}
router
.route("/proxy/{url}", {
let ctx = self.clone();
@@ -103,6 +129,7 @@ impl UiContext for RpcContext {
}
}),
)
.layer(axum::middleware::from_fn(https_redirect_if_public_http))
}
}
@@ -229,20 +256,6 @@ pub fn refresher() -> Router {
}))
}
pub fn redirecter() -> Router {
Router::new().fallback(get(|request: Request| async move {
Redirect::temporary(&format!(
"https://{}{}",
request
.headers()
.get(HOST)
.and_then(|s| s.to_str().ok())
.unwrap_or("localhost"),
request.uri()
))
}))
}
async fn proxy_request(ctx: RpcContext, request: Request, url: String) -> Result<Response, Error> {
if_authorized(&ctx, request, |mut request| async {
for header in PROXY_STRIP_HEADERS {

282
core/startos/src/net/tls.rs Normal file
View File

@@ -0,0 +1,282 @@
use std::sync::Arc;
use std::task::Poll;
use futures::future::BoxFuture;
use futures::FutureExt;
use imbl_value::InternedString;
use tokio::io::AsyncWriteExt;
use tokio_rustls::rustls::server::{Acceptor, ClientHello, ResolvesServerCert};
use tokio_rustls::rustls::sign::CertifiedKey;
use tokio_rustls::rustls::ServerConfig;
use tokio_rustls::LazyConfigAcceptor;
use visit_rs::{Visit, VisitFields};
use crate::net::web_server::{Accept, AcceptStream, MetadataVisitor};
use crate::prelude::*;
use crate::util::io::{BackTrackingIO, ReadWriter};
use crate::util::serde::MaybeUtf8String;
#[derive(Debug, Clone, VisitFields)]
pub struct TlsMetadata<M> {
pub inner: M,
pub tls_info: TlsHandshakeInfo,
}
impl<V: MetadataVisitor<Result = ()>, M: Visit<V>> Visit<V> for TlsMetadata<M> {
fn visit(&self, visitor: &mut V) -> <V as visit_rs::Visitor>::Result {
self.visit_fields(visitor).collect()
}
}
#[derive(Debug, Clone)]
pub struct TlsHandshakeInfo {
pub sni: Option<InternedString>,
pub alpn: Vec<MaybeUtf8String>,
}
impl<V: MetadataVisitor> Visit<V> for TlsHandshakeInfo {
fn visit(&self, visitor: &mut V) -> <V as visit_rs::Visitor>::Result {
visitor.visit(self)
}
}
pub trait TlsHandler<A: Accept> {
fn get_config<'a>(
self,
hello: &'a ClientHello<'a>,
metadata: &'a A::Metadata,
) -> impl Future<Output = Option<ServerConfig>> + Send + 'a
where
Self: 'a,
A: 'a,
A::Metadata: 'a;
}
#[derive(Clone)]
pub struct ChainedHandler<H0, H1>(pub H0, pub H1);
impl<A, H0, H1> TlsHandler<A> for ChainedHandler<H0, H1>
where
A: Accept,
<A as Accept>::Metadata: Send + Sync,
H0: TlsHandler<A> + Send,
H1: TlsHandler<A> + Send,
{
async fn get_config<'a>(
self,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig>
where
Self: 'a,
{
if let Some(config) = self.0.get_config(hello, metadata).await {
return Some(config);
}
self.1.get_config(hello, metadata).await
}
}
pub struct TlsHandlerWrapper<I, W> {
pub inner: I,
pub wrapper: W,
}
pub trait WrapTlsHandler<A: Accept> {
fn wrap<'a>(
self,
prev: ServerConfig,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> impl Future<Output = Option<ServerConfig>> + Send + 'a
where
Self: 'a;
}
impl<A, I, W> TlsHandler<A> for TlsHandlerWrapper<I, W>
where
A: Accept,
<A as Accept>::Metadata: Send + Sync,
I: TlsHandler<A> + Send,
W: WrapTlsHandler<A> + Send,
{
async fn get_config<'a>(
self,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig>
where
Self: 'a,
A: 'a,
<A as Accept>::Metadata: 'a,
{
let prev = self.inner.get_config(hello, metadata).await?;
self.wrapper.wrap(prev, hello, metadata).await
}
}
#[derive(Debug)]
pub struct SingleCertResolver(pub Arc<CertifiedKey>);
impl ResolvesServerCert for SingleCertResolver {
fn resolve(&self, _: ClientHello) -> Option<Arc<CertifiedKey>> {
Some(self.0.clone())
}
}
pub struct TlsListener<A: Accept, H: TlsHandler<A>> {
pub accept: A,
pub tls_handler: H,
in_progress:
Vec<BoxFuture<'static, Result<Option<(TlsMetadata<A::Metadata>, AcceptStream)>, Error>>>,
}
impl<A: Accept, H: TlsHandler<A>> TlsListener<A, H> {
pub fn new(accept: A, cert_handler: H) -> Self {
Self {
accept,
tls_handler: cert_handler,
in_progress: Vec::new(),
}
}
}
impl<A, H> Accept for TlsListener<A, H>
where
A: Accept,
A::Metadata: Send + 'static,
H: TlsHandler<A> + Clone + Send + 'static,
{
type Metadata = TlsMetadata<A::Metadata>;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
if let Some((idx, res)) = self
.in_progress
.iter_mut()
.enumerate()
.find_map(|(idx, fut)| match fut.poll_unpin(cx) {
Poll::Ready(a) => Some((idx, a)),
Poll::Pending => None,
})
{
self.in_progress.swap_remove(idx);
if let Some(res) = res.transpose() {
return Poll::Ready(res);
}
}
if let Poll::Ready((metadata, stream)) = self.accept.poll_accept(cx)? {
let tls_handler = self.tls_handler.clone();
self.in_progress.push(
async move {
let mut acceptor =
LazyConfigAcceptor::new(Acceptor::default(), BackTrackingIO::new(stream));
let mut mid: tokio_rustls::StartHandshake<BackTrackingIO<AcceptStream>> =
match (&mut acceptor).await {
Ok(a) => a,
Err(e) => {
let mut stream = acceptor.take_io().or_not_found("acceptor io")?;
let (_, buf) = stream.rewind();
if std::str::from_utf8(buf)
.ok()
.and_then(|buf| {
buf.lines()
.map(|l| l.trim())
.filter(|l| !l.is_empty())
.next()
})
.map_or(false, |buf| {
regex::Regex::new("[A-Z]+ (.+) HTTP/1")
.unwrap()
.is_match(buf)
})
{
handle_http_on_https(stream).await.log_err();
return Ok(None);
} else {
return Err(e).with_kind(ErrorKind::Network);
}
}
};
let hello = mid.client_hello();
if let Some(cfg) = tls_handler.get_config(&hello, &metadata).await {
let metadata = TlsMetadata {
inner: metadata,
tls_info: TlsHandshakeInfo {
sni: hello.server_name().map(InternedString::intern),
alpn: hello
.alpn()
.into_iter()
.flatten()
.map(|a| MaybeUtf8String(a.to_vec()))
.collect(),
},
};
let buffered = mid.io.stop_buffering();
mid.io
.write_all(&buffered)
.await
.with_kind(ErrorKind::Network)?;
return Ok(Some((
metadata,
Box::pin(mid.into_stream(Arc::new(cfg)).await?) as AcceptStream,
)));
}
Ok(None)
}
.boxed(),
);
}
Poll::Pending
}
}
async fn handle_http_on_https(stream: impl ReadWriter + Unpin + 'static) -> Result<(), Error> {
use axum::body::Body;
use axum::extract::Request;
use axum::response::Response;
use http::Uri;
use crate::net::static_server::server_error;
hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
.serve_connection(
hyper_util::rt::TokioIo::new(stream),
hyper_util::service::TowerToHyperService::new(axum::Router::new().fallback(
axum::routing::method_routing::any(move |req: Request| async move {
match async move {
let host = req
.headers()
.get(http::header::HOST)
.and_then(|host| host.to_str().ok());
if let Some(host) = host {
let uri = Uri::from_parts({
let mut parts = req.uri().to_owned().into_parts();
parts.scheme = Some("https".parse()?);
parts.authority = Some(host.parse()?);
parts
})?;
Response::builder()
.status(http::StatusCode::TEMPORARY_REDIRECT)
.header(http::header::LOCATION, uri.to_string())
.body(Body::default())
} else {
Response::builder()
.status(http::StatusCode::BAD_REQUEST)
.body(Body::from("Host header required"))
}
}
.await
{
Ok(a) => a,
Err(e) => {
tracing::warn!("Error redirecting http request on ssl port: {e}");
tracing::error!("{e:?}");
server_error(Error::new(e, ErrorKind::Network))
}
}
}),
)),
)
.await
.map_err(|e| Error::new(color_eyre::eyre::Report::msg(e), ErrorKind::Network))
}

View File

@@ -5,15 +5,62 @@ use async_stream::try_stream;
use color_eyre::eyre::eyre;
use futures::stream::BoxStream;
use futures::{StreamExt, TryStreamExt};
use imbl::OrdMap;
use imbl_value::InternedString;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use models::GatewayId;
use nix::net::if_::if_nametoindex;
use tokio::net::{TcpListener, TcpStream};
use tokio::process::Command;
use crate::db::model::public::{NetworkInterfaceInfo, NetworkInterfaceType};
use crate::prelude::*;
use crate::util::collections::OrdMapIterMut;
use crate::util::Invoke;
pub async fn load_network_interface_info() -> Result<OrdMap<GatewayId, NetworkInterfaceInfo>, Error>
{
let output = String::from_utf8(
Command::new("ip")
.arg("-o")
.arg("addr")
.arg("show")
.invoke(crate::ErrorKind::Network)
.await?,
)?;
let err_fn = || {
Error::new(
eyre!("malformed output from `ip`"),
crate::ErrorKind::Network,
)
};
let mut res = OrdMap::<GatewayId, NetworkInterfaceInfo>::new();
for line in output.lines() {
let split = line.split_ascii_whitespace().collect::<Vec<_>>();
let iface = GatewayId::from(InternedString::from(*split.get(1).ok_or_else(&err_fn)?));
let subnet: IpNet = split.get(3).ok_or_else(&err_fn)?.parse()?;
let info = res.entry(iface).or_default();
let ip_info = info.ip_info.get_or_insert_default();
ip_info.scope_id = split
.get(0)
.ok_or_else(&err_fn)?
.strip_suffix(":")
.ok_or_else(&err_fn)?
.parse()?;
ip_info.subnets.insert(subnet);
}
for (id, info) in OrdMapIterMut::from(&mut res) {
let ip_info = info.ip_info.get_or_insert_default();
ip_info.device_type = probe_iface_type(id.as_str()).await;
}
Ok(res)
}
pub fn ipv6_is_link_local(addr: Ipv6Addr) -> bool {
(addr.segments()[0] & 0xffc0) == 0xfe80
}
@@ -75,6 +122,22 @@ pub async fn get_iface_ipv6_addr(iface: &str) -> Result<Option<(Ipv6Addr, Ipv6Ne
.transpose()?)
}
pub async fn probe_iface_type(iface: &str) -> Option<NetworkInterfaceType> {
match tokio::fs::read_to_string(Path::new("/sys/class/net").join(iface).join("uevent"))
.await
.ok()?
.lines()
.find_map(|l| l.strip_prefix("DEVTYPE="))
{
Some("wlan") => Some(NetworkInterfaceType::Wireless),
Some("bridge") => Some(NetworkInterfaceType::Bridge),
Some("wireguard") => Some(NetworkInterfaceType::Wireguard),
None if iface_is_physical(iface).await => Some(NetworkInterfaceType::Ethernet),
None if iface_is_loopback(iface).await => Some(NetworkInterfaceType::Loopback),
_ => None,
}
}
pub async fn iface_is_physical(iface: &str) -> bool {
tokio::fs::metadata(Path::new("/sys/class/net").join(iface).join("device"))
.await
@@ -87,6 +150,19 @@ pub async fn iface_is_wireless(iface: &str) -> bool {
.is_ok()
}
pub async fn iface_is_bridge(iface: &str) -> bool {
tokio::fs::metadata(Path::new("/sys/class/net").join(iface).join("bridge"))
.await
.is_ok()
}
pub async fn iface_is_loopback(iface: &str) -> bool {
tokio::fs::read_to_string(Path::new("/sys/class/net").join(iface).join("type"))
.await
.ok()
.map_or(false, |x| x.trim() == "772")
}
pub fn list_interfaces() -> BoxStream<'static, Result<String, Error>> {
try_stream! {
let mut ifaces = tokio::fs::read_dir("/sys/class/net").await?;

View File

@@ -1,20 +1,22 @@
use std::any::Any;
use std::collections::{BTreeMap, BTreeSet};
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, Weak};
use async_acme::acme::{ACME_TLS_ALPN_NAME, Identifier};
use async_acme::acme::{Identifier, ACME_TLS_ALPN_NAME};
use axum::body::Body;
use axum::extract::Request;
use axum::response::Response;
use color_eyre::eyre::eyre;
use futures::future::BoxFuture;
use futures::FutureExt;
use helpers::NonDetachingJoinHandle;
use http::Uri;
use http::{Extensions, Uri};
use imbl::OrdMap;
use imbl_value::InternedString;
use imbl_value::{InOMap, InternedString};
use itertools::Itertools;
use models::{GatewayId, ResultExt};
use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn};
use rpc_toolkit::{from_fn, Context, HandlerArgs, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;
@@ -23,29 +25,29 @@ use tokio_rustls::rustls::crypto::CryptoProvider;
use tokio_rustls::rustls::pki_types::{
CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName,
};
use tokio_rustls::rustls::server::{Acceptor, ResolvesServerCert};
use tokio_rustls::rustls::sign::CertifiedKey;
use tokio_rustls::rustls::{RootCertStore, ServerConfig};
use tokio_rustls::rustls::server::{Acceptor, ClientHello};
use tokio_rustls::rustls::{ClientConfig, RootCertStore, ServerConfig};
use tokio_rustls::{LazyConfigAcceptor, TlsConnector};
use tokio_stream::StreamExt;
use tokio_stream::wrappers::WatchStream;
use tokio_stream::StreamExt;
use tracing::instrument;
use ts_rs::TS;
use crate::context::{CliContext, RpcContext};
use crate::db::model::Database;
use crate::db::model::public::NetworkInterfaceInfo;
use crate::net::acme::{AcmeCertCache, AcmeProvider};
use crate::db::model::Database;
use crate::net::acme::{AcmeCertCache, AcmeTlsAlpnCache, AcmeTlsHandler};
use crate::net::gateway::{
Accepted, AnyFilter, DynInterfaceFilter, InterfaceFilter, NetworkInterfaceController,
NetworkInterfaceListener,
AnyFilter, BindTcp, DynInterfaceFilter, GatewayInfo, InterfaceFilter, NetworkInterfaceController, NetworkInterfaceListener
};
use crate::net::static_server::server_error;
use crate::net::tls::{ChainedHandler, TlsHandler, TlsListener, WrapTlsHandler};
use crate::net::web_server::{Accept, AcceptStream, extract};
use crate::prelude::*;
use crate::util::collections::EqSet;
use crate::util::io::BackTrackingIO;
use crate::util::serde::{HandlerExtSerde, MaybeUtf8String, display_serializable};
use crate::util::sync::SyncMutex;
use crate::util::serde::{display_serializable, HandlerExtSerde, MaybeUtf8String};
use crate::util::sync::{SyncMutex, Watch};
pub fn vhost_api<C: Context>() -> ParentHandler<C> {
ParentHandler::new().subcommand(
@@ -61,8 +63,7 @@ pub fn vhost_api<C: Context>() -> ParentHandler<C> {
}
let mut table = Table::new();
table
.add_row(row![bc => "FROM", "TO", "GATEWAYS", "ACME", "CONNECT SSL", "ACTIVE"]);
table.add_row(row![bc => "FROM", "TO", "GATEWAYS", "CONNECT SSL", "ACTIVE"]);
for (external, targets) in res {
for (host, targets) in targets {
@@ -75,7 +76,6 @@ pub fn vhost_api<C: Context>() -> ParentHandler<C> {
),
target.addr,
target.gateways.iter().join(", "),
target.acme.as_ref().map(|a| a.0.as_str()).unwrap_or("NONE"),
target.connect_ssl.is_ok(),
idx == 0
]);
@@ -91,22 +91,14 @@ pub fn vhost_api<C: Context>() -> ParentHandler<C> {
)
}
#[derive(Debug)]
struct SingleCertResolver(Arc<CertifiedKey>);
impl ResolvesServerCert for SingleCertResolver {
fn resolve(&self, _: tokio_rustls::rustls::server::ClientHello) -> Option<Arc<CertifiedKey>> {
Some(self.0.clone())
}
}
// not allowed: <=1024, >=32768, 5355, 5432, 9050, 6010, 9051, 5353
pub struct VHostController {
db: TypedPatchDb<Database>,
interfaces: Arc<NetworkInterfaceController>,
crypto_provider: Arc<CryptoProvider>,
acme_tls_alpn_cache: AcmeTlsAlpnCache,
servers: SyncMutex<BTreeMap<u16, VHostServer>>,
acme_cache: AcmeTlsAlpnCache,
servers: SyncMutex<BTreeMap<u16, VHostServer<NetworkInterfaceListener>>>,
}
impl VHostController {
pub fn new(db: TypedPatchDb<Database>, interfaces: Arc<NetworkInterfaceController>) -> Self {
@@ -114,7 +106,7 @@ impl VHostController {
db,
interfaces,
crypto_provider: Arc::new(tokio_rustls::rustls::crypto::ring::default_provider()),
acme_tls_alpn_cache: Arc::new(SyncMutex::new(BTreeMap::new())),
acme_cache: Arc::new(SyncMutex::new(BTreeMap::new())),
servers: SyncMutex::new(BTreeMap::new()),
}
}
@@ -123,19 +115,18 @@ impl VHostController {
&self,
hostname: Option<InternedString>,
external: u16,
target: TargetInfo,
target: impl VHostTarget<NetworkInterfaceListener>,
) -> Result<Arc<()>, Error> {
self.servers.mutate(|writable| {
let server = if let Some(server) = writable.remove(&external) {
server
} else {
VHostServer::new(
external,
self.interfaces.watcher.bind(BindTcp, external)?,
self.db.clone(),
self.interfaces.clone(),
self.crypto_provider.clone(),
self.acme_tls_alpn_cache.clone(),
)?
self.acme_cache.clone(),
)
};
let rc = server.add(hostname, target);
writable.insert(external, server);
@@ -185,43 +176,158 @@ impl VHostController {
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct TargetInfo {
pub filter: DynInterfaceFilter,
pub acme: Option<AcmeProvider>,
pub addr: SocketAddr,
pub connect_ssl: Result<(), AlpnInfo>, // Ok: yes, connect using ssl, pass through alpn; Err: connect tcp, use provided strategy for alpn
pub trait VHostTarget<A: Accept>: std::fmt::Debug + Eq {
type PreprocessRes: Send + 'static;
#[allow(unused_variables)]
fn skip(&self, metadata: &<A as Accept>::Metadata) -> bool {
false
}
fn preprocess<'a>(
&'a self,
prev: ServerConfig,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> impl Future<Output = Option<(ServerConfig, Self::PreprocessRes)>> + Send + 'a;
fn handle_stream(&self, stream: AcceptStream, prev: Self::PreprocessRes);
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
pub struct ShowTargetInfo {
pub gateways: BTreeSet<GatewayId>,
pub acme: Option<AcmeProvider>,
pub addr: SocketAddr,
pub connect_ssl: Result<(), AlpnInfo>, // Ok: yes, connect using ssl, pass through alpn; Err: connect tcp, use provided strategy for alpn
pub trait DynVHostTargetT<A: Accept>: std::fmt::Debug + Any {
fn skip(&self, metadata: &<A as Accept>::Metadata) -> bool;
fn preprocess<'a>(
&'a self,
prev: ServerConfig,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> BoxFuture<'a, Option<(ServerConfig, Box<dyn Any + Send>)>>;
fn handle_stream(&self, stream: AcceptStream, prev: Box<dyn Any + Send>);
fn eq(&self, other: &dyn DynVHostTargetT<A>) -> bool;
}
impl ShowTargetInfo {
pub fn new(
TargetInfo {
filter,
acme,
addr,
connect_ssl,
}: TargetInfo,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
) -> Self {
ShowTargetInfo {
gateways: ip_info
.iter()
.filter(|(id, info)| filter.filter(*id, *info))
.map(|(k, _)| k)
.cloned()
.collect(),
acme,
addr,
connect_ssl,
impl<A: Accept, T: VHostTarget<A> + 'static> DynVHostTargetT<A> for T {
fn skip(&self, metadata: &<A as Accept>::Metadata) -> bool {
VHostTarget::skip(self, metadata)
}
fn preprocess<'a>(
&'a self,
prev: ServerConfig,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> BoxFuture<'a, Option<(ServerConfig, Box<dyn Any + Send>)>> {
VHostTarget::preprocess(self, prev, hello, metadata)
.map(|o| o.map(|(cfg, res)| (cfg, Box::new(res) as Box<dyn Any + Send>)))
.boxed()
}
fn handle_stream(&self, stream: AcceptStream, prev: Box<dyn Any + Send>) {
if let Ok(prev) = prev.downcast() {
VHostTarget::handle_stream(self, stream, *prev);
}
}
fn eq(&self, other: &dyn DynVHostTargetT<A>) -> bool {
Some(self) == (other as &dyn Any).downcast_ref()
}
}
struct DynVHostTarget<A: Accept>(Arc<dyn DynVHostTargetT<A> + Send + Sync>);
impl<A: Accept> Clone for DynVHostTarget<A> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<A: Accept + 'static> PartialEq for DynVHostTarget<A> {
fn eq(&self, other: &Self) -> bool {
self.0.eq(&*other.0)
}
}
impl<A: Accept + 'static> Eq for DynVHostTarget<A> {}
struct Preprocessed<A: Accept>(DynVHostTarget<A>, Box<dyn Any + Send>);
impl<A: Accept + 'static> DynVHostTarget<A> {
async fn into_preprocessed(
self,
prev: ServerConfig,
hello: &ClientHello<'_>,
metadata: &<A as Accept>::Metadata,
) -> Option<(ServerConfig, Preprocessed<A>)> {
let (cfg, res) = self.0.preprocess(prev, hello, metadata).await?;
Some((cfg, Preprocessed(self, res)))
}
}
impl<A: Accept + 'static> Preprocessed<A> {
fn finish(self, stream: AcceptStream) {
(self.0).0.handle_stream(stream, self.1);
}
}
#[derive(Debug, Clone)]
pub struct ProxyTarget {
pub filter: DynInterfaceFilter,
pub addr: SocketAddr,
pub connect_ssl: Result<Arc<ClientConfig>, AlpnInfo>, // Ok: yes, connect using ssl, pass through alpn; Err: connect tcp, use provided strategy for alpn
}
impl PartialEq for ProxyTarget {
fn eq(&self, other: &Self) -> bool {
self.filter == other.filter
&& self.addr == other.addr
&& self.connect_ssl.as_ref().err() == other.connect_ssl.as_ref().err()
}
}
impl Eq for ProxyTarget {}
impl<A> VHostTarget<A> for ProxyTarget
where
A: Accept + 'static,
<A as Accept>::Metadata: Send + Sync,
{
type PreprocessRes = AcceptStream;
fn skip(&self, metadata: &<A as Accept>::Metadata) -> bool {
let info = extract::<GatewayInfo,_>(metadata)
}
async fn preprocess<'a>(
&'a self,
mut prev: ServerConfig,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> Option<(ServerConfig, Self::PreprocessRes)> {
let tcp_stream = TcpStream::connect(self.addr)
.await
.with_ctx(|_| (ErrorKind::Network, self.addr))
.log_err()?;
match &self.connect_ssl {
Ok(client_cfg) => {
let mut client_cfg = (&**client_cfg).clone();
client_cfg.alpn_protocols = hello
.alpn()
.into_iter()
.flatten()
.map(|x| x.to_vec())
.collect();
let target_stream = TlsConnector::from(Arc::new(client_cfg))
.connect_with(
ServerName::IpAddress(self.addr.ip().into()),
tcp_stream,
|conn| {
prev.alpn_protocols
.extend(conn.alpn_protocol().into_iter().map(|p| p.to_vec()))
},
)
.await
.log_err()?;
return Some((prev, Box::pin(target_stream)));
}
Err(AlpnInfo::Reflect) => {
for alpn in hello.alpn().into_iter().flatten() {
prev.alpn_protocols.push(alpn.to_vec());
}
}
Err(AlpnInfo::Specified(a)) => {
for alpn in a {
prev.alpn_protocols.push(alpn.0.clone());
}
}
}
Some((prev, Box::pin(tcp_stream)))
}
fn handle_stream(&self, mut stream: AcceptStream, mut prev: Self::PreprocessRes) {
tokio::spawn(async move { tokio::io::copy_bidirectional(&mut stream, &mut prev).await });
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, TS)]
@@ -237,17 +343,57 @@ impl Default for AlpnInfo {
}
}
type AcmeTlsAlpnCache =
Arc<SyncMutex<BTreeMap<InternedString, watch::Receiver<Option<Arc<CertifiedKey>>>>>>;
type Mapping = BTreeMap<Option<InternedString>, BTreeMap<TargetInfo, Weak<()>>>;
type Mapping<A: Accept> = BTreeMap<Option<InternedString>, InOMap<DynVHostTarget<A>, Weak<()>>>;
struct VHostServer {
mapping: watch::Sender<Mapping>,
pub struct VHostConnector<'a, A: Accept + 'static>(&'a Watch<Mapping<A>>, Option<Preprocessed<A>>);
impl<'b, A> WrapTlsHandler<A> for &'b mut VHostConnector<'b, A>
where
A: Accept + 'static,
<A as Accept>::Metadata: Send + Sync,
{
async fn wrap<'a>(
self,
prev: ServerConfig,
hello: &'a ClientHello<'a>,
metadata: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig>
where
Self: 'a,
{
if hello
.alpn()
.into_iter()
.flatten()
.any(|a| a == ACME_TLS_ALPN_NAME)
{
return Some(prev);
}
let target = self.0.peek(|m| {
m.get(&hello.server_name().map(InternedString::from))
.into_iter()
.flatten()
.filter(|(_, rc)| rc.strong_count() > 0)
.find(|(t, _)| !t.skip(metadata))
.map(|(e, _)| e.clone())
})?;
let (prev, store) = target.into_preprocessed(prev, hello, metadata).await?;
self.1 = Some(store);
Some(prev)
}
}
struct VHostServer<A: Accept + 'static> {
mapping: Watch<Mapping<A>>,
_thread: NonDetachingJoinHandle<()>,
}
impl<'a> From<&'a BTreeMap<Option<InternedString>, BTreeMap<TargetInfo, Weak<()>>>> for AnyFilter {
fn from(value: &'a BTreeMap<Option<InternedString>, BTreeMap<TargetInfo, Weak<()>>>) -> Self {
impl<'a> From<&'a BTreeMap<Option<InternedString>, BTreeMap<ProxyTarget, Weak<()>>>> for AnyFilter {
fn from(value: &'a BTreeMap<Option<InternedString>, BTreeMap<ProxyTarget, Weak<()>>>) -> Self {
Self(
value
.iter()
@@ -690,34 +836,35 @@ impl VHostServer {
}
#[instrument(skip_all)]
fn new(
port: u16,
fn new<A: Accept>(
listener: A,
db: TypedPatchDb<Database>,
iface_ctrl: Arc<NetworkInterfaceController>,
crypto_provider: Arc<CryptoProvider>,
acme_tls_alpn_cache: AcmeTlsAlpnCache,
acme_cache: AcmeTlsAlpnCache,
) -> Result<Self, Error> {
let mut listener = iface_ctrl
.watcher
.bind(port)
.with_kind(crate::ErrorKind::Network)?;
let (map_send, map_recv) = watch::channel(BTreeMap::new());
let mapping = Watch::new(BTreeMap::new());
Ok(Self {
mapping: map_send,
mapping: mapping.clone(),
_thread: tokio::spawn(async move {
let listener = TlsListener::new(
listener,
VHostTlsHandler {
cert_handler: ChainedHandler(
&AcmeTlsHandler {
db: &db,
acme_cache: &acme_cache,
crypto_provider: &crypto_provider,
get_provider: todo!(),
in_progress: Watch::new(BTreeSet::new()),
},
todo!(),
),
alpn_handler: todo!(),
},
);
loop {
if let Err(e) = Self::accept(
&mut listener,
map_recv.clone(),
db.clone(),
acme_tls_alpn_cache.clone(),
crypto_provider.clone(),
)
.await
{
tracing::error!(
"VHostController: failed to accept connection on {port}: {e}"
);
if let Err(e) = Self::accept(&mut listener, &mapping).await {
tracing::error!("VHostController: failed to accept connection: {e}");
tracing::debug!("{e:?}");
}
}
@@ -725,7 +872,7 @@ impl VHostServer {
.into(),
})
}
fn add(&self, hostname: Option<InternedString>, target: TargetInfo) -> Result<Arc<()>, Error> {
fn add(&self, hostname: Option<InternedString>, target: ProxyTarget) -> Result<Arc<()>, Error> {
let mut res = Ok(Arc::new(()));
self.mapping.send_if_modified(|writable| {
let mut changed = false;

View File

@@ -1,73 +1,136 @@
use std::any::Any;
use std::future::Future;
use std::net::SocketAddr;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
use axum::extract::ConnectInfo;
use axum::Router;
use futures::future::Either;
use futures::FutureExt;
use helpers::NonDetachingJoinHandle;
use http::Extensions;
use hyper_util::rt::{TokioIo, TokioTimer};
use tokio::net::{TcpListener, TcpStream};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use visit_rs::{Visit, Visitor};
use crate::net::gateway::{
lookup_info_by_addr, NetworkInterfaceListener, SelfContainedNetworkInterfaceListener,
};
use crate::net::static_server::{redirecter, refresher, ui_router, UiContext};
use crate::net::static_server::{ui_router, UiContext};
use crate::prelude::*;
use crate::util::actor::background::BackgroundJobQueue;
use crate::util::io::ReadWriter;
use crate::util::sync::{SyncRwLock, Watch};
pub struct Accepted {
pub type AcceptStream = Pin<Box<dyn ReadWriter + Send + 'static>>;
pub trait MetadataVisitor: Visitor<Result = ()> {
fn visit<M: Clone + Send + Sync + 'static>(&mut self, metadata: &M) -> Self::Result;
}
pub struct ExtensionVisitor<'a>(&'a mut Extensions);
impl<'a> Visitor for ExtensionVisitor<'a> {
type Result = ();
}
impl<'a> MetadataVisitor for ExtensionVisitor<'a> {
fn visit<M: Clone + Send + Sync + 'static>(&mut self, metadata: &M) -> Self::Result {
self.0.insert(metadata.clone());
}
}
pub struct ExtractVisitor<T>(Option<T>);
impl<T> Visitor for ExtractVisitor<T> {
type Result = ();
}
impl<T: Clone + Send + Sync + 'static> MetadataVisitor for ExtractVisitor<T> {
fn visit<M: Clone + Send + Sync + 'static>(&mut self, metadata: &M) -> Self::Result {
if let Some(matching) = (metadata as &dyn Any).downcast_ref::<T>() {
self.0 = Some(matching.clone());
}
}
}
pub fn extract<T, M: Visit<ExtractVisitor<T>>>(metadata: &M) -> Option<T> {
let mut visitor = ExtractVisitor(None);
visitor.visit(metadata);
visitor.0
}
#[derive(Clone, Copy, Debug)]
pub struct TcpMetadata {
pub peer_addr: SocketAddr,
pub local_addr: SocketAddr,
pub https_redirect: bool,
pub stream: TcpStream,
}
impl<V: MetadataVisitor> Visit<V> for TcpMetadata {
fn visit(&self, visitor: &mut V) -> <V as visit_rs::Visitor>::Result {
visitor.visit(self)
}
}
pub trait Accept {
fn poll_accept(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<Accepted, Error>>;
type Metadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>>;
}
impl Accept for Vec<TcpListener> {
fn poll_accept(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<Accepted, Error>> {
for listener in &*self {
if let Poll::Ready((stream, peer_addr)) = listener.poll_accept(cx)? {
return Poll::Ready(Ok(Accepted {
local_addr: listener.local_addr()?,
impl Accept for TcpListener {
type Metadata = TcpMetadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
if let Poll::Ready((stream, peer_addr)) = TcpListener::poll_accept(self, cx)? {
if let Err(e) = socket2::SockRef::from(&stream).set_tcp_keepalive(
&socket2::TcpKeepalive::new()
.with_time(Duration::from_secs(900))
.with_interval(Duration::from_secs(60))
.with_retries(5),
) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
return Poll::Ready(Ok((
TcpMetadata {
local_addr: self.local_addr()?,
peer_addr,
https_redirect: false,
stream,
}));
},
Box::pin(stream),
)));
}
Poll::Pending
}
}
impl<A> Accept for Vec<A>
where
A: Accept,
{
type Metadata = A::Metadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
for listener in self {
if let Poll::Ready(accepted) = listener.poll_accept(cx)? {
return Poll::Ready(Ok(accepted));
}
}
Poll::Pending
}
}
impl Accept for NetworkInterfaceListener {
fn poll_accept(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<Accepted, Error>> {
NetworkInterfaceListener::poll_accept(self, cx, &true).map(|res| {
res.map(|a| {
let public = self
.ip_info
.peek(|i| lookup_info_by_addr(i, a.bind).map_or(true, |(_, i)| i.public()));
Accepted {
peer_addr: a.peer,
local_addr: a.bind,
https_redirect: public,
stream: a.stream,
}
})
})
}
}
impl<A: Accept, B: Accept> Accept for Either<A, B> {
fn poll_accept(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<Accepted, Error>> {
impl<A, B> Accept for Either<A, B>
where
A: Accept,
B: Accept<Metadata = A::Metadata>,
{
type Metadata = A::Metadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
match self {
Either::Left(a) => a.poll_accept(cx),
Either::Right(b) => b.poll_accept(cx),
@@ -75,7 +138,11 @@ impl<A: Accept, B: Accept> Accept for Either<A, B> {
}
}
impl<A: Accept> Accept for Option<A> {
fn poll_accept(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<Accepted, Error>> {
type Metadata = A::Metadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
match self {
None => Poll::Pending,
Some(a) => a.poll_accept(cx),
@@ -98,12 +165,15 @@ impl<A: Accept + Send + Sync + 'static> Acceptor<A> {
self.acceptor.poll_changed(cx)
}
fn poll_accept(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<Accepted, Error>> {
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(A::Metadata, AcceptStream), Error>> {
let _ = self.poll_changed(cx);
self.acceptor.peek_mut(|a| a.poll_accept(cx))
}
async fn accept(&mut self) -> Result<Accepted, Error> {
async fn accept(&mut self) -> Result<(A::Metadata, AcceptStream), Error> {
std::future::poll_fn(|cx| self.poll_accept(cx)).await
}
}
@@ -115,19 +185,14 @@ impl Acceptor<Vec<TcpListener>> {
}
}
pub type UpgradableListener =
Option<Either<SelfContainedNetworkInterfaceListener, NetworkInterfaceListener>>;
impl Acceptor<UpgradableListener> {
pub fn bind_upgradable(listener: SelfContainedNetworkInterfaceListener) -> Self {
Self::new(Some(Either::Left(listener)))
}
}
pub struct WebServerAcceptorSetter<A: Accept> {
acceptor: Watch<A>,
}
impl<A: Accept, B: Accept> WebServerAcceptorSetter<Option<Either<A, B>>> {
impl<A, B> WebServerAcceptorSetter<Option<Either<A, B>>>
where
A: Accept,
B: Accept<Metadata = A::Metadata>,
{
pub fn try_upgrade<F: FnOnce(A) -> Result<B, Error>>(&self, f: F) -> Result<(), Error> {
let mut res = Ok(());
self.acceptor.send_modify(|a| {
@@ -154,20 +219,24 @@ impl<A: Accept> Deref for WebServerAcceptorSetter<A> {
pub struct WebServer<A: Accept> {
shutdown: oneshot::Sender<()>,
router: Watch<Option<Router>>,
router: Watch<Router>,
acceptor: Watch<A>,
thread: NonDetachingJoinHandle<()>,
}
impl<A: Accept + Send + Sync + 'static> WebServer<A> {
impl<A> WebServer<A>
where
A: Accept + Send + Sync + 'static,
for<'a> A::Metadata: Visit<ExtensionVisitor<'a>> + Send + Sync + 'static,
{
pub fn acceptor_setter(&self) -> WebServerAcceptorSetter<A> {
WebServerAcceptorSetter {
acceptor: self.acceptor.clone(),
}
}
pub fn new(mut acceptor: Acceptor<A>) -> Self {
pub fn new(mut acceptor: Acceptor<A>, router: Router) -> Self {
let acceptor_send = acceptor.acceptor.clone();
let router = Watch::<Option<Router>>::new(None);
let router = Watch::new(router);
let service = router.clone_unseen();
let (shutdown, shutdown_recv) = oneshot::channel();
let thread = NonDetachingJoinHandle::from(tokio::spawn(async move {
@@ -190,13 +259,14 @@ impl<A: Accept + Send + Sync + 'static> WebServer<A> {
}
}
struct SwappableRouter {
router: Watch<Option<Router>>,
redirect: bool,
local_addr: SocketAddr,
peer_addr: SocketAddr,
struct SwappableRouter<M> {
router: Watch<Router>,
metadata: M,
}
impl hyper::service::Service<hyper::Request<hyper::body::Incoming>> for SwappableRouter {
impl<M: for<'a> Visit<ExtensionVisitor<'a>> + Send + Sync + 'static>
hyper::service::Service<hyper::Request<hyper::body::Incoming>>
for SwappableRouter<M>
{
type Response = <Router as tower_service::Service<
hyper::Request<hyper::body::Incoming>,
>>::Response;
@@ -210,19 +280,10 @@ impl<A: Accept + Send + Sync + 'static> WebServer<A> {
fn call(&self, mut req: hyper::Request<hyper::body::Incoming>) -> Self::Future {
use tower_service::Service;
req.extensions_mut()
.insert(ConnectInfo((self.peer_addr, self.local_addr)));
self.metadata
.visit(&mut ExtensionVisitor(req.extensions_mut()));
if self.redirect {
redirecter().call(req)
} else {
let router = self.router.read();
if let Some(mut router) = router {
router.call(req)
} else {
refresher().call(req)
}
}
self.router.read().call(req)
}
}
@@ -249,18 +310,15 @@ impl<A: Accept + Send + Sync + 'static> WebServer<A> {
let mut err = None;
for _ in 0..5 {
if let Err(e) = async {
let accepted = acceptor.accept().await?;
let src = accepted.stream.peer_addr().ok();
let (metadata, stream) = acceptor.accept().await?;
queue.add_job(
graceful.watch(
server
.serve_connection_with_upgrades(
TokioIo::new(accepted.stream),
TokioIo::new(stream),
SwappableRouter {
router: service.clone(),
redirect: accepted.https_redirect,
peer_addr: accepted.peer_addr,
local_addr: accepted.local_addr,
metadata,
},
)
.into_owned(),
@@ -314,7 +372,7 @@ impl<A: Accept + Send + Sync + 'static> WebServer<A> {
}
pub fn serve_router(&mut self, router: Router) {
self.router.send(Some(router))
self.router.send(router)
}
pub fn serve_ui_for<C: UiContext>(&mut self, ctx: C) {

View File

@@ -3,8 +3,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use clap::Parser;
use imbl_value::InternedString;
use ipnet::Ipv4Net;
use models::GatewayId;
use rpc_toolkit::{Context, Empty, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use rpc_toolkit::{from_fn_async, Context, Empty, HandlerArgs, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use crate::context::CliContext;
@@ -13,11 +12,11 @@ use crate::prelude::*;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::db::GatewayPort;
use crate::tunnel::wg::{ClientConfig, WgConfig, WgSubnetClients, WgSubnetConfig};
use crate::util::serde::{HandlerExtSerde, display_serializable};
use crate::util::serde::{display_serializable, HandlerExtSerde};
pub fn tunnel_api<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
.subcommand("web", super::web::web_api())
.subcommand("web", super::web::web_api::<C>())
.subcommand(
"db",
super::db::db_api::<C>()
@@ -165,9 +164,11 @@ pub async fn add_subnet(
.db
.mutate(|db| {
let map = db.as_wg_mut().as_subnets_mut();
if !map.contains_key(&subnet)? {
map.insert(&subnet, &WgSubnetConfig::new(name))?;
}
map.upsert(&subnet, || {
Ok(WgSubnetConfig::new(InternedString::default()))
})?
.as_name_mut()
.ser(&name)?;
db.as_wg().de()
})
.await
@@ -312,7 +313,8 @@ pub async fn show_config(
}: ShowConfigParams,
SubnetParams { subnet }: SubnetParams,
) -> Result<ClientConfig, Error> {
let wg = ctx.db.peek().await.into_wg();
let peek = ctx.db.peek().await;
let wg = peek.as_wg();
let client = wg
.as_subnets()
.as_idx(&subnet)
@@ -329,16 +331,19 @@ pub async fn show_config(
}
}) {
wan_addr
} else if let Some(webserver) = peek.as_webserver().de()? {
webserver.ip()
} else {
ctx.net_iface
.ip_info()
.into_iter()
.find_map(|(_, info)| {
info.public()
.then_some(info.ip_info)
.flatten()
.into_iter()
.find_map(|info| info.subnets.into_iter().next())
.peek(|i| {
i.iter().find_map(|(_, info)| {
info.ip_info
.as_ref()
.filter(|_| info.public())
.iter()
.find_map(|info| info.subnets.iter().next())
.copied()
})
})
.or_not_found("a public IP address")?
.addr()

View File

@@ -1,15 +1,13 @@
use std::net::IpAddr;
use clap::Parser;
use imbl::HashMap;
use imbl_value::{InternedString, json};
use imbl_value::InternedString;
use itertools::Itertools;
use patch_db::HasModel;
use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use rpc_toolkit::{from_fn_async, Context, HandlerArgs, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use crate::auth::{Sessions, check_password};
use crate::auth::{check_password, Sessions};
use crate::context::CliContext;
use crate::middleware::auth::AuthContext;
use crate::middleware::signature::SignatureAuthContext;
@@ -18,7 +16,7 @@ use crate::rpc_continuations::OpenAuthedContinuations;
use crate::sign::AnyVerifyingKey;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::db::TunnelDatabase;
use crate::util::serde::{HandlerExtSerde, display_serializable};
use crate::util::serde::{display_serializable, HandlerExtSerde};
use crate::util::sync::SyncMutex;
impl SignatureAuthContext for TunnelContext {
@@ -31,14 +29,18 @@ impl SignatureAuthContext for TunnelContext {
async fn sig_context(
&self,
) -> impl IntoIterator<Item = Result<impl AsRef<str> + Send, Error>> + Send {
self.addrs
.iter()
.filter(|a| !match a {
IpAddr::V4(a) => a.is_loopback() || a.is_unspecified(),
IpAddr::V6(a) => a.is_loopback() || a.is_unspecified(),
})
.map(|a| InternedString::from_display(&a))
.map(Ok)
let peek = self.db().peek().await;
peek.as_webserver()
.de()
.map(|a| a.as_ref().map(InternedString::from_display))
.transpose()
.into_iter()
.chain(
std::iter::from_fn(move || Some(peek.as_certificates().keys()))
.flatten_ok()
.map_ok(|h| h.0)
.flatten_ok(),
)
}
fn check_pubkey(
db: &Model<Self::Database>,
@@ -77,7 +79,7 @@ impl AuthContext for TunnelContext {
&self.open_authed_continuations
}
fn check_password(db: &Model<Self::Database>, password: &str) -> Result<(), Error> {
check_password(&db.as_password().de()?, password)
check_password(&db.as_password().de()?.unwrap_or_default(), password)
}
}
@@ -204,7 +206,8 @@ pub async fn set_password_rpc(
password.as_bytes(),
&rand::random::<[u8; 16]>(),
&argon2::Config::rfc9106_low_mem(),
)?;
)
.with_kind(ErrorKind::PasswordHashGeneration)?;
ctx.db
.mutate(|db| db.as_password_mut().ser(&Some(pwhash)))
.await
@@ -234,7 +237,7 @@ pub async fn set_password_cli(
context
.call_remote::<TunnelContext>(
&parent_method.iter().chain(method.iter()).join("."),
to_value(SetPasswordParams { password }),
to_value(&SetPasswordParams { password })?,
)
.await?;

View File

@@ -1,5 +1,5 @@
use std::collections::{BTreeMap, BTreeSet};
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::Arc;
@@ -23,16 +23,16 @@ use url::Url;
use crate::auth::Sessions;
use crate::context::config::ContextConfig;
use crate::context::{CliContext, RpcContext};
use crate::db::model::public::NetworkInterfaceType;
use crate::db::model::public::{NetworkInterfaceInfo, NetworkInterfaceType};
use crate::middleware::auth::AuthContext;
use crate::net::forward::PortForwardController;
use crate::net::gateway::{IdFilter, InterfaceFilter, NetworkInterfaceWatcher};
use crate::prelude::*;
use crate::rpc_continuations::{OpenAuthedContinuations, RpcContinuations};
use crate::tunnel::db::{GatewayPort, TunnelDatabase};
use crate::tunnel::TUNNEL_DEFAULT_PORT;
use crate::tunnel::TUNNEL_DEFAULT_LISTEN;
use crate::util::io::read_file_to_string;
use crate::util::sync::SyncMutex;
use crate::util::sync::{SyncMutex, Watch};
use crate::util::Invoke;
#[derive(Debug, Clone, Default, Deserialize, Serialize, Parser)]
@@ -67,16 +67,14 @@ impl TunnelConfig {
pub struct TunnelContextSeed {
pub listen: SocketAddr,
pub addrs: BTreeSet<IpAddr>,
pub db: TypedPatchDb<TunnelDatabase>,
pub datadir: PathBuf,
pub rpc_continuations: RpcContinuations,
pub open_authed_continuations: OpenAuthedContinuations<Option<InternedString>>,
pub ephemeral_sessions: SyncMutex<Sessions>,
pub net_iface: NetworkInterfaceWatcher,
pub net_iface: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
pub forward: PortForwardController,
pub active_forwards: SyncMutex<BTreeMap<GatewayPort, Arc<()>>>,
pub masquerade_thread: NonDetachingJoinHandle<()>,
pub shutdown: Sender<()>,
}
@@ -101,12 +99,9 @@ impl TunnelContext {
|| async { Ok(Default::default()) },
)
.await?;
let listen = config.tunnel_listen.unwrap_or(SocketAddr::new(
Ipv6Addr::UNSPECIFIED.into(),
TUNNEL_DEFAULT_PORT,
));
let net_iface = NetworkInterfaceWatcher::new(async { OrdMap::new() }, []);
let forward = PortForwardController::new(net_iface.subscribe());
let listen = config.tunnel_listen.unwrap_or(TUNNEL_DEFAULT_LISTEN);
let net_iface = Watch::new(crate::net::utils::load_network_interface_info().await?);
let forward = PortForwardController::new(net_iface.clone_unseen());
Command::new("sysctl")
.arg("-w")
@@ -114,55 +109,45 @@ impl TunnelContext {
.invoke(ErrorKind::Network)
.await?;
let mut masquerade_net_iface = net_iface.subscribe();
let masquerade_thread = tokio::spawn(async move {
loop {
for iface in masquerade_net_iface.peek(|i| {
i.iter()
.filter(|(_, info)| {
dbg!(info).ip_info.as_ref().map_or(false, |i| {
dbg!(i).device_type != Some(NetworkInterfaceType::Wireguard)
})
})
.map(|(name, _)| name)
.cloned()
.collect::<Vec<_>>()
}) {
if Command::new("iptables")
.arg("-t")
.arg("nat")
.arg("-C")
.arg("POSTROUTING")
.arg("-o")
.arg(iface.as_str())
.arg("-j")
.arg("MASQUERADE")
.invoke(ErrorKind::Network)
.await
.is_err()
{
tracing::info!("Adding masquerade rule for interface {}", iface);
Command::new("iptables")
.arg("-t")
.arg("nat")
.arg("-A")
.arg("POSTROUTING")
.arg("-o")
.arg(iface.as_str())
.arg("-j")
.arg("MASQUERADE")
.invoke(ErrorKind::Network)
.await
.log_err();
}
}
masquerade_net_iface.changed().await;
tracing::info!("Network interfaces changed, updating masquerade rules");
for iface in net_iface.peek(|i| {
i.iter()
.filter(|(_, info)| {
dbg!(info).ip_info.as_ref().map_or(false, |i| {
dbg!(i).device_type != Some(NetworkInterfaceType::Wireguard)
})
})
.map(|(name, _)| name)
.cloned()
.collect::<Vec<_>>()
}) {
if Command::new("iptables")
.arg("-t")
.arg("nat")
.arg("-C")
.arg("POSTROUTING")
.arg("-o")
.arg(iface.as_str())
.arg("-j")
.arg("MASQUERADE")
.invoke(ErrorKind::Network)
.await
.is_err()
{
tracing::info!("Adding masquerade rule for interface {}", iface);
Command::new("iptables")
.arg("-t")
.arg("nat")
.arg("-A")
.arg("POSTROUTING")
.arg("-o")
.arg(iface.as_str())
.arg("-j")
.arg("MASQUERADE")
.invoke(ErrorKind::Network)
.await
.log_err();
}
})
.into();
}
let peek = db.peek().await;
peek.as_wg().de()?.sync().await?;
@@ -178,11 +163,6 @@ impl TunnelContext {
Ok(Self(Arc::new(TunnelContextSeed {
listen,
addrs: crate::net::utils::all_socket_addrs_for(listen.port())
.await?
.into_iter()
.map(|(_, a)| a.ip())
.collect(),
db,
datadir,
rpc_continuations: RpcContinuations::new(),
@@ -191,7 +171,6 @@ impl TunnelContext {
net_iface,
forward,
active_forwards: SyncMutex::new(active_forwards),
masquerade_thread,
shutdown,
})))
}
@@ -222,6 +201,14 @@ impl CallRemote<TunnelContext> for CliContext {
params: Value,
_: Empty,
) -> Result<Value, RpcError> {
let (tunnel_addr, addr_from_config) = if let Some(addr) = self.tunnel_addr {
(addr, true)
} else if let Some(addr) = self.tunnel_listen {
(addr, true)
} else {
(TUNNEL_DEFAULT_LISTEN, false)
};
let local =
if let Ok(local) = read_file_to_string(TunnelContext::LOCAL_AUTH_COOKIE_PATH).await {
self.cookie_store
@@ -229,11 +216,11 @@ impl CallRemote<TunnelContext> for CliContext {
.unwrap()
.insert_raw(
&Cookie::build(("local", local))
.domain("localhost")
.domain(&tunnel_addr.ip().to_string())
.expires(Expiration::Session)
.same_site(SameSite::Strict)
.build(),
&"http://localhost".parse()?,
&format!("http://{tunnel_addr}").parse()?,
)
.with_kind(crate::ErrorKind::Network)?;
true
@@ -241,24 +228,12 @@ impl CallRemote<TunnelContext> for CliContext {
false
};
let tunnel_addr = if let Some(addr) = self.tunnel_addr {
Some(addr)
} else if let Some(addr) = self.tunnel_listen {
Some(addr)
} else {
None
};
let (url, sig_ctx) = if let Some(tunnel_addr) = tunnel_addr {
let (url, sig_ctx) = if local && tunnel_addr.ip().is_loopback() {
(format!("http://{tunnel_addr}/rpc/v0").parse()?, None)
} else if addr_from_config {
(
format!("https://{tunnel_addr}/rpc/v0").parse()?,
Some(InternedString::from_display(
&self.tunnel_listen.unwrap_or(tunnel_addr).ip(),
)),
)
} else if local {
(
format!("http://localhost:{TUNNEL_DEFAULT_PORT}/rpc/v0").parse()?,
None,
Some(InternedString::from_display(&tunnel_addr.ip())),
)
} else {
return Err(Error::new(eyre!("`--tunnel` required"), ErrorKind::InvalidRequest).into());

View File

@@ -2,18 +2,18 @@ use std::collections::{BTreeMap, BTreeSet};
use std::net::{IpAddr, SocketAddr, SocketAddrV4};
use std::path::PathBuf;
use clap::Parser;
use clap::builder::ValueParserFactory;
use clap::Parser;
use imbl::HashMap;
use imbl_value::InternedString;
use itertools::Itertools;
use models::{FromStrParser, GatewayId};
use openssl::pkey::{PKey, Private};
use openssl::x509::X509;
use patch_db::Dump;
use patch_db::json_ptr::{JsonPointer, ROOT};
use patch_db::Dump;
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use rpc_toolkit::{from_fn_async, Context, HandlerArgs, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use tracing::instrument;
use ts_rs::TS;
@@ -28,7 +28,7 @@ use crate::tunnel::context::TunnelContext;
use crate::tunnel::web::TunnelCertData;
use crate::tunnel::wg::WgServer;
use crate::util::serde::{
HandlerExtSerde, Pem, apply_expr, deserialize_from_str, serialize_display,
apply_expr, deserialize_from_str, serialize_display, HandlerExtSerde, Pem,
};
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
@@ -84,7 +84,7 @@ pub struct TunnelDatabase {
pub sessions: Sessions,
pub password: Option<String>,
pub auth_pubkeys: HashMap<AnyVerifyingKey, SignerInfo>,
pub certificate: BTreeMap<JsonKey<BTreeSet<InternedString>>, TunnelCertData>,
pub certificates: BTreeMap<JsonKey<BTreeSet<InternedString>>, TunnelCertData>,
pub wg: WgServer,
pub port_forwards: PortForwards,
}

View File

@@ -1,3 +1,5 @@
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use axum::Router;
use futures::future::ready;
use rpc_toolkit::Server;
@@ -17,6 +19,10 @@ pub mod web;
pub mod wg;
pub const TUNNEL_DEFAULT_PORT: u16 = 5960;
pub const TUNNEL_DEFAULT_LISTEN: SocketAddr = SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::new(127, 0, 59, 60),
TUNNEL_DEFAULT_PORT,
));
pub fn tunnel_router(ctx: TunnelContext) -> Router {
use axum::extract as x;

View File

@@ -1,24 +1,22 @@
use std::{
collections::BTreeSet,
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
};
use std::collections::{BTreeSet, VecDeque};
use std::io::Write;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use crate::{
context::CliContext, net::ssl::SANInfo, prelude::*, tunnel::context::TunnelContext,
util::serde::Pem,
};
use clap::Parser;
use futures::AsyncWriteExt;
use imbl_value::{InternedString, json};
use imbl_value::{json, InternedString};
use itertools::Itertools;
use openssl::{
pkey::{PKey, Private},
x509::{GeneralName, X509},
};
use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use openssl::pkey::{PKey, Private};
use openssl::x509::X509;
use rpc_toolkit::{from_fn_async, Context, Empty, HandlerArgs, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader};
use crate::context::CliContext;
use crate::net::ssl::SANInfo;
use crate::prelude::*;
use crate::tunnel::context::TunnelContext;
use crate::util::serde::Pem;
#[derive(Debug, Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct TunnelCertData {
@@ -39,11 +37,18 @@ pub fn web_api<C: Context>() -> ParentHandler<C> {
)
.subcommand(
"import-certificate",
from_fn_async(set_certificate)
from_fn_async(import_certificate)
.with_about("Import a certificate to use for the webserver")
.no_display()
.with_call_remote::<CliContext>(),
)
// .subcommand(
// "forget-certificate",
// from_fn_async(forget_certificate)
// .with_about("Forget a certificate that was imported into the webserver")
// .no_display()
// .with_call_remote::<CliContext>(),
// )
.subcommand(
"uninit",
from_fn_async(uninit_web)
@@ -53,7 +58,10 @@ pub fn web_api<C: Context>() -> ParentHandler<C> {
)
}
pub async fn set_certificate(ctx: TunnelContext, cert_data: TunnelCertData) -> Result<(), Error> {
pub async fn import_certificate(
ctx: TunnelContext,
cert_data: TunnelCertData,
) -> Result<(), Error> {
let mut saninfo = BTreeSet::new();
let leaf = cert_data.cert.get(0).ok_or_else(|| {
Error::new(
@@ -66,18 +74,20 @@ pub async fn set_certificate(ctx: TunnelContext, cert_data: TunnelCertData) -> R
saninfo.insert(dns.into());
}
if let Some(ip) = san.ipaddress() {
if ip.len() == 4 {
saninfo.insert(InternedString::from_display(&Ipv4Addr::new(
ip[0], ip[1], ip[2], ip[3],
if let Ok::<[u8; 4], _>(ip) = ip.try_into() {
saninfo.insert(InternedString::from_display(&Ipv4Addr::from_bits(
u32::from_be_bytes(ip),
)));
} else if let Ok::<[u8; 16], _>(ip) = ip.try_into() {
saninfo.insert(InternedString::from_display(&Ipv6Addr::from_bits(
u128::from_be_bytes(ip),
)));
} else if ip.len() == 16 {
saninfo.insert(InternedString::from_display(&Ipv6Addr::from_bits(bits)))
}
}
}
ctx.db
.mutate(|db| {
db.as_certificate_mut()
db.as_certificates_mut()
.insert(&JsonKey(saninfo), &cert_data)
})
.await
@@ -85,7 +95,7 @@ pub async fn set_certificate(ctx: TunnelContext, cert_data: TunnelCertData) -> R
Ok(())
}
#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Deserialize, Serialize, Parser)]
pub struct InitWebParams {
listen: SocketAddr,
}
@@ -96,7 +106,7 @@ pub async fn init_web_rpc(
) -> Result<(), Error> {
ctx.db
.mutate(|db| {
if db.as_certificate().de()?.is_empty() {
if db.as_certificates().de()?.is_empty() {
return Err(Error::new(
eyre!("No certificate available"),
ErrorKind::OpenSsl,
@@ -119,7 +129,7 @@ pub async fn init_web_rpc(
pub async fn uninit_web(ctx: TunnelContext) -> Result<(), Error> {
ctx.db
.mutate(|db| db.as_webserver_mut().ser(&false))
.mutate(|db| db.as_webserver_mut().ser(&None))
.await
.result
}
@@ -129,18 +139,19 @@ pub async fn init_web_cli(
context,
parent_method,
method,
params: InitWebParams { listen },
..
}: HandlerArgs<CliContext>,
}: HandlerArgs<CliContext, InitWebParams>,
) -> Result<(), Error> {
loop {
match context
.call_remote(
.call_remote::<TunnelContext>(
&parent_method.iter().chain(method.iter()).join("."),
json!({}),
to_value(&InitWebParams { listen })?,
)
.await
{
Ok(a) => println!("Webserver Initialized"),
Ok(_) => println!("Webserver Initialized"),
Err(e) if e.code == ErrorKind::OpenSsl as i32 => {
println!(
"StartTunnel has not been set up with an SSL Certificate yet. Setting one up now..."
@@ -155,11 +166,15 @@ pub async fn init_web_cli(
let self_signed;
loop {
match readline.readline().await.with_kind(ErrorKind::Filesystem)? {
rustyline_async::ReadlineEvent::Line(l) if l.trim() == "1" => {
rustyline_async::ReadlineEvent::Line(l)
if l.trim_matches(|c: char| c.is_whitespace() || c == '"') == "1" =>
{
self_signed = true;
break;
}
rustyline_async::ReadlineEvent::Line(l) if l.trim() == "2" => {
rustyline_async::ReadlineEvent::Line(l)
if l.trim_matches(|c: char| c.is_whitespace() || c == '"') == "2" =>
{
self_signed = false;
break;
}
@@ -167,21 +182,47 @@ pub async fn init_web_cli(
readline.clear_history();
readline.add_history_entry("1".into());
readline.add_history_entry("2".into());
writer
.write_all(b"Invalid response. Enter either \"1\" or \"2\".")
.await?;
writeln!(writer, "Invalid response. Enter either \"1\" or \"2\".")?;
}
_ => return Err(Error::new(eyre!("Aborted"), ErrorKind::Unknown)),
}
}
drop((readline, writer));
if self_signed {
writeln!(
writer,
"Enter the name(s) to sign the certificate for, separated by commas."
)?;
readline.clear_history();
readline
.update_prompt(&format!("Subject Alternative Name(s) [{}]: ", listen.ip()))
.with_kind(ErrorKind::Filesystem)?;
let mut saninfo = BTreeSet::new();
loop {
match readline.readline().await.with_kind(ErrorKind::Filesystem)? {
rustyline_async::ReadlineEvent::Line(l) if !l.trim().is_empty() => {
saninfo.extend(l.split(",").map(|h| h.trim().into()));
break;
}
rustyline_async::ReadlineEvent::Line(_) => {
readline.clear_history();
}
_ => return Err(Error::new(eyre!("Aborted"), ErrorKind::Unknown)),
}
}
let key = crate::net::ssl::gen_nistp256()?;
let cert = crate::net::ssl::make_self_signed((
&key,
&SANInfo::new(&[].into_iter().collect()),
))?;
let cert = crate::net::ssl::make_self_signed((&key, &SANInfo::new(&saninfo)))?;
context
.call_remote::<TunnelContext>(
"web.import-certificate",
to_value(&TunnelCertData {
key: Pem(key),
cert: Pem(vec![cert]),
})?,
)
.await?;
} else {
drop((readline, writer));
println!("Please paste in your PEM encoded private key: ");
let mut stdin_lines = BufReader::new(tokio::io::stdin()).lines();
let mut key_string = String::new();
@@ -225,7 +266,7 @@ pub async fn init_web_cli(
}
context
.call_remote(
.call_remote::<TunnelContext>(
"web.import-certificate",
to_value(&TunnelCertData {
key,
@@ -235,7 +276,19 @@ pub async fn init_web_cli(
.await?;
}
}
Err(e) if e.code == ErrorKind::Authorization as i32 => {}
Err(e) if e.code == ErrorKind::Authorization as i32 => {
println!("A password has not been setup yet. Setting one up now...");
super::auth::set_password_cli(HandlerArgs {
context: context.clone(),
parent_method: vec!["auth", "set-password"].into(),
method: VecDeque::new(),
params: Empty {},
inherited_params: Empty {},
raw_params: json!({}),
})
.await?;
}
Err(e) => return Err(e.into()),
}
}

View File

@@ -1201,7 +1201,7 @@ impl PemEncoding for X509 {
impl PemEncoding for Vec<X509> {
fn from_pem<E: serde::de::Error>(pem: &str) -> Result<Self, E> {
X509::stack_from_pem(pem).map_err(E::custom)
X509::stack_from_pem(pem.as_bytes()).map_err(E::custom)
}
fn to_pem<E: serde::ser::Error>(&self) -> Result<String, E> {
self.iter()