fix: use shared futures for ACME cert acquisition with 2m timeout

This commit is contained in:
Aiden McClelland
2026-03-19 00:06:34 -06:00
parent 9a58568053
commit 292a914307
2 changed files with 61 additions and 42 deletions

View File

@@ -2,11 +2,13 @@ use std::collections::{BTreeMap, BTreeSet};
use std::net::IpAddr; use std::net::IpAddr;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use async_acme::acme::{ACME_TLS_ALPN_NAME, Identifier}; use async_acme::acme::{ACME_TLS_ALPN_NAME, Identifier};
use clap::Parser; use clap::Parser;
use clap::builder::ValueParserFactory; use clap::builder::ValueParserFactory;
use futures::StreamExt; use futures::future::{BoxFuture, Shared};
use futures::{FutureExt, StreamExt};
use imbl_value::InternedString; use imbl_value::InternedString;
use itertools::Itertools; use itertools::Itertools;
use openssl::pkey::{PKey, Private}; use openssl::pkey::{PKey, Private};
@@ -42,7 +44,8 @@ pub struct AcmeTlsHandler<M: HasModel, S> {
pub acme_cache: AcmeTlsAlpnCache, pub acme_cache: AcmeTlsAlpnCache,
pub crypto_provider: Arc<CryptoProvider>, pub crypto_provider: Arc<CryptoProvider>,
pub get_provider: S, pub get_provider: S,
pub in_progress: Watch<BTreeSet<BTreeSet<InternedString>>>, pub in_progress:
Watch<BTreeMap<BTreeSet<InternedString>, Shared<BoxFuture<'static, Option<CertifiedKey>>>>>,
} }
impl<M, S> AcmeTlsHandler<M, S> impl<M, S> AcmeTlsHandler<M, S>
where where
@@ -50,8 +53,9 @@ where
+ DbAccessMut<AcmeCertStore> + DbAccessMut<AcmeCertStore>
+ HasModel<Model = Model<M>> + HasModel<Model = Model<M>>
+ Send + Send
+ Sync, + Sync
S: GetAcmeProvider + Clone, + 'static,
S: GetAcmeProvider + Clone + Send + Sync,
{ {
pub async fn get_cert(&self, san_info: &BTreeSet<InternedString>) -> Option<CertifiedKey> { pub async fn get_cert(&self, san_info: &BTreeSet<InternedString>) -> Option<CertifiedKey> {
let provider = self.get_provider.get_provider(san_info).await?; let provider = self.get_provider.get_provider(san_info).await?;
@@ -71,6 +75,8 @@ where
.and_then(|c| should_use_cert(&c.0).log_err()) .and_then(|c| should_use_cert(&c.0).log_err())
.unwrap_or(false) .unwrap_or(false)
{ {
self.in_progress
.send_if_modified(|map| map.remove(san_info).is_some());
return Some( return Some(
CertifiedKey::from_der( CertifiedKey::from_der(
cert.fullchain cert.fullchain
@@ -88,21 +94,6 @@ where
} }
} }
if !self.in_progress.send_if_modified(|x| {
if !x.contains(san_info) {
x.insert(san_info.clone());
true
} else {
false
}
}) {
self.in_progress
.clone()
.wait_for(|x| !x.contains(san_info))
.await;
continue;
}
let contact = <M as DbAccessByKey<AcmeSettings>>::access_by_key(&peek, &provider)? let contact = <M as DbAccessByKey<AcmeSettings>>::access_by_key(&peek, &provider)?
.as_contact() .as_contact()
.de() .de()
@@ -121,31 +112,53 @@ where
.cloned() .cloned()
.map(|d| (d, Watch::new(None))) .map(|d| (d, Watch::new(None)))
.collect::<BTreeMap<_, _>>(); .collect::<BTreeMap<_, _>>();
self.acme_cache.mutate(|c| {
c.extend(cache_entries.iter().map(|(k, v)| (k.clone(), v.clone())));
});
let cert = async_acme::rustls_helper::order( let cert = self
|identifier, cert| { .in_progress
let domain = InternedString::from_display(&identifier); .send_modify(|map| {
if let Some(entry) = cache_entries.get(&domain) { if let Some(fut) = map.get(san_info).cloned() {
entry.send(Some(Arc::new(cert))); if fut.peek().map_or(true, |f| f.is_some()) {
return fut;
}
} }
Ok(()) let provider = provider.clone();
}, let acme_cache = self.acme_cache.clone();
provider.0.as_str(), let db = self.db.clone();
&identifiers, let fut = async move {
Some(&AcmeCertCache(&self.db)), acme_cache.mutate(|c| {
&contact, c.extend(cache_entries.iter().map(|(k, v)| (k.clone(), v.clone())));
) });
.await
.log_err()?;
self.acme_cache let cert = tokio::time::timeout(
.mutate(|c| c.retain(|c, _| !cache_entries.contains_key(c))); Duration::from_secs(120),
async_acme::rustls_helper::order(
|identifier, cert| {
let domain = InternedString::from_display(&identifier);
if let Some(entry) = cache_entries.get(&domain) {
entry.send(Some(Arc::new(cert)));
}
Ok(())
},
provider.0.as_str(),
&identifiers,
Some(&AcmeCertCache(&db)),
&contact,
),
)
.await
.log_err()?
.log_err()?;
self.in_progress.send_modify(|i| i.remove(san_info)); acme_cache.mutate(|c| c.retain(|c, _| !cache_entries.contains_key(c)));
Some(cert)
}
.boxed()
.shared();
map.insert(san_info.clone(), fut.clone());
fut
})
.await?;
return Some(cert); return Some(cert);
} }
} }
@@ -166,7 +179,8 @@ where
+ DbAccessMut<AcmeCertStore> + DbAccessMut<AcmeCertStore>
+ HasModel<Model = Model<M>> + HasModel<Model = Model<M>>
+ Send + Send
+ Sync, + Sync
+ 'static,
S: GetAcmeProvider + Clone + Send + Sync, S: GetAcmeProvider + Clone + Send + Sync,
{ {
async fn get_config( async fn get_config(
@@ -462,6 +476,7 @@ impl ValueParserFactory for AcmeProvider {
} }
#[derive(Deserialize, Serialize, Parser, TS)] #[derive(Deserialize, Serialize, Parser, TS)]
#[group(skip)]
#[ts(export)] #[ts(export)]
pub struct InitAcmeParams { pub struct InitAcmeParams {
#[arg(long, help = "help.arg.acme-provider")] #[arg(long, help = "help.arg.acme-provider")]
@@ -488,6 +503,7 @@ pub async fn init(
} }
#[derive(Deserialize, Serialize, Parser, TS)] #[derive(Deserialize, Serialize, Parser, TS)]
#[group(skip)]
#[ts(export)] #[ts(export)]
pub struct RemoveAcmeParams { pub struct RemoveAcmeParams {
#[arg(long, help = "help.arg.acme-provider")] #[arg(long, help = "help.arg.acme-provider")]

View File

@@ -64,6 +64,7 @@ pub struct PassthroughInfo {
} }
#[derive(Debug, Clone, Deserialize, Serialize, Parser)] #[derive(Debug, Clone, Deserialize, Serialize, Parser)]
#[group(skip)]
#[serde(rename_all = "kebab-case")] #[serde(rename_all = "kebab-case")]
struct AddPassthroughParams { struct AddPassthroughParams {
#[arg(long)] #[arg(long)]
@@ -79,6 +80,7 @@ struct AddPassthroughParams {
} }
#[derive(Debug, Clone, Deserialize, Serialize, Parser)] #[derive(Debug, Clone, Deserialize, Serialize, Parser)]
#[group(skip)]
#[serde(rename_all = "kebab-case")] #[serde(rename_all = "kebab-case")]
struct RemovePassthroughParams { struct RemovePassthroughParams {
#[arg(long)] #[arg(long)]
@@ -959,7 +961,8 @@ where
+ DbAccessMut<AcmeCertStore> + DbAccessMut<AcmeCertStore>
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider> + DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ Send + Send
+ Sync, + Sync
+ 'static,
A: Accept + 'static, A: Accept + 'static,
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>> <A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
+ Visit<ExtractVisitor<GatewayInfo>> + Visit<ExtractVisitor<GatewayInfo>>
@@ -1088,7 +1091,7 @@ impl<A: Accept> VHostServer<A> {
acme_cache, acme_cache,
crypto_provider: crypto_provider.clone(), crypto_provider: crypto_provider.clone(),
get_provider: GetVHostAcmeProvider(mapping.clone()), get_provider: GetVHostAcmeProvider(mapping.clone()),
in_progress: Watch::new(BTreeSet::new()), in_progress: Watch::new(BTreeMap::new()),
}), }),
RootCaTlsHandler { RootCaTlsHandler {
db, db,