mirror of
https://github.com/Start9Labs/start-os.git
synced 2026-03-26 02:11:53 +00:00
fix: use shared futures for ACME cert acquisition with 2m timeout
This commit is contained in:
@@ -2,11 +2,13 @@ use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_acme::acme::{ACME_TLS_ALPN_NAME, Identifier};
|
||||
use clap::Parser;
|
||||
use clap::builder::ValueParserFactory;
|
||||
use futures::StreamExt;
|
||||
use futures::future::{BoxFuture, Shared};
|
||||
use futures::{FutureExt, StreamExt};
|
||||
use imbl_value::InternedString;
|
||||
use itertools::Itertools;
|
||||
use openssl::pkey::{PKey, Private};
|
||||
@@ -42,7 +44,8 @@ pub struct AcmeTlsHandler<M: HasModel, S> {
|
||||
pub acme_cache: AcmeTlsAlpnCache,
|
||||
pub crypto_provider: Arc<CryptoProvider>,
|
||||
pub get_provider: S,
|
||||
pub in_progress: Watch<BTreeSet<BTreeSet<InternedString>>>,
|
||||
pub in_progress:
|
||||
Watch<BTreeMap<BTreeSet<InternedString>, Shared<BoxFuture<'static, Option<CertifiedKey>>>>>,
|
||||
}
|
||||
impl<M, S> AcmeTlsHandler<M, S>
|
||||
where
|
||||
@@ -50,8 +53,9 @@ where
|
||||
+ DbAccessMut<AcmeCertStore>
|
||||
+ HasModel<Model = Model<M>>
|
||||
+ Send
|
||||
+ Sync,
|
||||
S: GetAcmeProvider + Clone,
|
||||
+ Sync
|
||||
+ 'static,
|
||||
S: GetAcmeProvider + Clone + Send + Sync,
|
||||
{
|
||||
pub async fn get_cert(&self, san_info: &BTreeSet<InternedString>) -> Option<CertifiedKey> {
|
||||
let provider = self.get_provider.get_provider(san_info).await?;
|
||||
@@ -71,6 +75,8 @@ where
|
||||
.and_then(|c| should_use_cert(&c.0).log_err())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
self.in_progress
|
||||
.send_if_modified(|map| map.remove(san_info).is_some());
|
||||
return Some(
|
||||
CertifiedKey::from_der(
|
||||
cert.fullchain
|
||||
@@ -88,21 +94,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
if !self.in_progress.send_if_modified(|x| {
|
||||
if !x.contains(san_info) {
|
||||
x.insert(san_info.clone());
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}) {
|
||||
self.in_progress
|
||||
.clone()
|
||||
.wait_for(|x| !x.contains(san_info))
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
|
||||
let contact = <M as DbAccessByKey<AcmeSettings>>::access_by_key(&peek, &provider)?
|
||||
.as_contact()
|
||||
.de()
|
||||
@@ -121,11 +112,26 @@ where
|
||||
.cloned()
|
||||
.map(|d| (d, Watch::new(None)))
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
self.acme_cache.mutate(|c| {
|
||||
|
||||
let cert = self
|
||||
.in_progress
|
||||
.send_modify(|map| {
|
||||
if let Some(fut) = map.get(san_info).cloned() {
|
||||
if fut.peek().map_or(true, |f| f.is_some()) {
|
||||
return fut;
|
||||
}
|
||||
}
|
||||
let provider = provider.clone();
|
||||
let acme_cache = self.acme_cache.clone();
|
||||
let db = self.db.clone();
|
||||
let fut = async move {
|
||||
acme_cache.mutate(|c| {
|
||||
c.extend(cache_entries.iter().map(|(k, v)| (k.clone(), v.clone())));
|
||||
});
|
||||
|
||||
let cert = async_acme::rustls_helper::order(
|
||||
let cert = tokio::time::timeout(
|
||||
Duration::from_secs(120),
|
||||
async_acme::rustls_helper::order(
|
||||
|identifier, cert| {
|
||||
let domain = InternedString::from_display(&identifier);
|
||||
if let Some(entry) = cache_entries.get(&domain) {
|
||||
@@ -135,17 +141,24 @@ where
|
||||
},
|
||||
provider.0.as_str(),
|
||||
&identifiers,
|
||||
Some(&AcmeCertCache(&self.db)),
|
||||
Some(&AcmeCertCache(&db)),
|
||||
&contact,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.log_err()?
|
||||
.log_err()?;
|
||||
|
||||
self.acme_cache
|
||||
.mutate(|c| c.retain(|c, _| !cache_entries.contains_key(c)));
|
||||
|
||||
self.in_progress.send_modify(|i| i.remove(san_info));
|
||||
acme_cache.mutate(|c| c.retain(|c, _| !cache_entries.contains_key(c)));
|
||||
|
||||
Some(cert)
|
||||
}
|
||||
.boxed()
|
||||
.shared();
|
||||
map.insert(san_info.clone(), fut.clone());
|
||||
fut
|
||||
})
|
||||
.await?;
|
||||
return Some(cert);
|
||||
}
|
||||
}
|
||||
@@ -166,7 +179,8 @@ where
|
||||
+ DbAccessMut<AcmeCertStore>
|
||||
+ HasModel<Model = Model<M>>
|
||||
+ Send
|
||||
+ Sync,
|
||||
+ Sync
|
||||
+ 'static,
|
||||
S: GetAcmeProvider + Clone + Send + Sync,
|
||||
{
|
||||
async fn get_config(
|
||||
@@ -462,6 +476,7 @@ impl ValueParserFactory for AcmeProvider {
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Parser, TS)]
|
||||
#[group(skip)]
|
||||
#[ts(export)]
|
||||
pub struct InitAcmeParams {
|
||||
#[arg(long, help = "help.arg.acme-provider")]
|
||||
@@ -488,6 +503,7 @@ pub async fn init(
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Parser, TS)]
|
||||
#[group(skip)]
|
||||
#[ts(export)]
|
||||
pub struct RemoveAcmeParams {
|
||||
#[arg(long, help = "help.arg.acme-provider")]
|
||||
|
||||
@@ -64,6 +64,7 @@ pub struct PassthroughInfo {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, Parser)]
|
||||
#[group(skip)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
struct AddPassthroughParams {
|
||||
#[arg(long)]
|
||||
@@ -79,6 +80,7 @@ struct AddPassthroughParams {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, Parser)]
|
||||
#[group(skip)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
struct RemovePassthroughParams {
|
||||
#[arg(long)]
|
||||
@@ -959,7 +961,8 @@ where
|
||||
+ DbAccessMut<AcmeCertStore>
|
||||
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
|
||||
+ Send
|
||||
+ Sync,
|
||||
+ Sync
|
||||
+ 'static,
|
||||
A: Accept + 'static,
|
||||
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
|
||||
+ Visit<ExtractVisitor<GatewayInfo>>
|
||||
@@ -1088,7 +1091,7 @@ impl<A: Accept> VHostServer<A> {
|
||||
acme_cache,
|
||||
crypto_provider: crypto_provider.clone(),
|
||||
get_provider: GetVHostAcmeProvider(mapping.clone()),
|
||||
in_progress: Watch::new(BTreeSet::new()),
|
||||
in_progress: Watch::new(BTreeMap::new()),
|
||||
}),
|
||||
RootCaTlsHandler {
|
||||
db,
|
||||
|
||||
Reference in New Issue
Block a user