fix: use shared futures for ACME cert acquisition with 2m timeout

This commit is contained in:
Aiden McClelland
2026-03-19 00:06:34 -06:00
parent 9a58568053
commit 292a914307
2 changed files with 61 additions and 42 deletions

View File

@@ -2,11 +2,13 @@ use std::collections::{BTreeMap, BTreeSet};
use std::net::IpAddr;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use async_acme::acme::{ACME_TLS_ALPN_NAME, Identifier};
use clap::Parser;
use clap::builder::ValueParserFactory;
use futures::StreamExt;
use futures::future::{BoxFuture, Shared};
use futures::{FutureExt, StreamExt};
use imbl_value::InternedString;
use itertools::Itertools;
use openssl::pkey::{PKey, Private};
@@ -42,7 +44,8 @@ pub struct AcmeTlsHandler<M: HasModel, S> {
pub acme_cache: AcmeTlsAlpnCache,
pub crypto_provider: Arc<CryptoProvider>,
pub get_provider: S,
pub in_progress: Watch<BTreeSet<BTreeSet<InternedString>>>,
pub in_progress:
Watch<BTreeMap<BTreeSet<InternedString>, Shared<BoxFuture<'static, Option<CertifiedKey>>>>>,
}
impl<M, S> AcmeTlsHandler<M, S>
where
@@ -50,8 +53,9 @@ where
+ DbAccessMut<AcmeCertStore>
+ HasModel<Model = Model<M>>
+ Send
+ Sync,
S: GetAcmeProvider + Clone,
+ Sync
+ 'static,
S: GetAcmeProvider + Clone + Send + Sync,
{
pub async fn get_cert(&self, san_info: &BTreeSet<InternedString>) -> Option<CertifiedKey> {
let provider = self.get_provider.get_provider(san_info).await?;
@@ -71,6 +75,8 @@ where
.and_then(|c| should_use_cert(&c.0).log_err())
.unwrap_or(false)
{
self.in_progress
.send_if_modified(|map| map.remove(san_info).is_some());
return Some(
CertifiedKey::from_der(
cert.fullchain
@@ -88,21 +94,6 @@ where
}
}
if !self.in_progress.send_if_modified(|x| {
if !x.contains(san_info) {
x.insert(san_info.clone());
true
} else {
false
}
}) {
self.in_progress
.clone()
.wait_for(|x| !x.contains(san_info))
.await;
continue;
}
let contact = <M as DbAccessByKey<AcmeSettings>>::access_by_key(&peek, &provider)?
.as_contact()
.de()
@@ -121,11 +112,26 @@ where
.cloned()
.map(|d| (d, Watch::new(None)))
.collect::<BTreeMap<_, _>>();
self.acme_cache.mutate(|c| {
let cert = self
.in_progress
.send_modify(|map| {
if let Some(fut) = map.get(san_info).cloned() {
if fut.peek().map_or(true, |f| f.is_some()) {
return fut;
}
}
let provider = provider.clone();
let acme_cache = self.acme_cache.clone();
let db = self.db.clone();
let fut = async move {
acme_cache.mutate(|c| {
c.extend(cache_entries.iter().map(|(k, v)| (k.clone(), v.clone())));
});
let cert = async_acme::rustls_helper::order(
let cert = tokio::time::timeout(
Duration::from_secs(120),
async_acme::rustls_helper::order(
|identifier, cert| {
let domain = InternedString::from_display(&identifier);
if let Some(entry) = cache_entries.get(&domain) {
@@ -135,17 +141,24 @@ where
},
provider.0.as_str(),
&identifiers,
Some(&AcmeCertCache(&self.db)),
Some(&AcmeCertCache(&db)),
&contact,
),
)
.await
.log_err()?
.log_err()?;
self.acme_cache
.mutate(|c| c.retain(|c, _| !cache_entries.contains_key(c)));
self.in_progress.send_modify(|i| i.remove(san_info));
acme_cache.mutate(|c| c.retain(|c, _| !cache_entries.contains_key(c)));
Some(cert)
}
.boxed()
.shared();
map.insert(san_info.clone(), fut.clone());
fut
})
.await?;
return Some(cert);
}
}
@@ -166,7 +179,8 @@ where
+ DbAccessMut<AcmeCertStore>
+ HasModel<Model = Model<M>>
+ Send
+ Sync,
+ Sync
+ 'static,
S: GetAcmeProvider + Clone + Send + Sync,
{
async fn get_config(
@@ -462,6 +476,7 @@ impl ValueParserFactory for AcmeProvider {
}
#[derive(Deserialize, Serialize, Parser, TS)]
#[group(skip)]
#[ts(export)]
pub struct InitAcmeParams {
#[arg(long, help = "help.arg.acme-provider")]
@@ -488,6 +503,7 @@ pub async fn init(
}
#[derive(Deserialize, Serialize, Parser, TS)]
#[group(skip)]
#[ts(export)]
pub struct RemoveAcmeParams {
#[arg(long, help = "help.arg.acme-provider")]

View File

@@ -64,6 +64,7 @@ pub struct PassthroughInfo {
}
#[derive(Debug, Clone, Deserialize, Serialize, Parser)]
#[group(skip)]
#[serde(rename_all = "kebab-case")]
struct AddPassthroughParams {
#[arg(long)]
@@ -79,6 +80,7 @@ struct AddPassthroughParams {
}
#[derive(Debug, Clone, Deserialize, Serialize, Parser)]
#[group(skip)]
#[serde(rename_all = "kebab-case")]
struct RemovePassthroughParams {
#[arg(long)]
@@ -959,7 +961,8 @@ where
+ DbAccessMut<AcmeCertStore>
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ Send
+ Sync,
+ Sync
+ 'static,
A: Accept + 'static,
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
+ Visit<ExtractVisitor<GatewayInfo>>
@@ -1088,7 +1091,7 @@ impl<A: Accept> VHostServer<A> {
acme_cache,
crypto_provider: crypto_provider.clone(),
get_provider: GetVHostAcmeProvider(mapping.clone()),
in_progress: Watch::new(BTreeSet::new()),
in_progress: Watch::new(BTreeMap::new()),
}),
RootCaTlsHandler {
db,