mirror of
https://github.com/Start9Labs/start-os.git
synced 2026-03-26 02:11:53 +00:00
fix: use shared futures for ACME cert acquisition with 2m timeout
This commit is contained in:
@@ -2,11 +2,13 @@ use std::collections::{BTreeMap, BTreeSet};
|
|||||||
use std::net::IpAddr;
|
use std::net::IpAddr;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use async_acme::acme::{ACME_TLS_ALPN_NAME, Identifier};
|
use async_acme::acme::{ACME_TLS_ALPN_NAME, Identifier};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use clap::builder::ValueParserFactory;
|
use clap::builder::ValueParserFactory;
|
||||||
use futures::StreamExt;
|
use futures::future::{BoxFuture, Shared};
|
||||||
|
use futures::{FutureExt, StreamExt};
|
||||||
use imbl_value::InternedString;
|
use imbl_value::InternedString;
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use openssl::pkey::{PKey, Private};
|
use openssl::pkey::{PKey, Private};
|
||||||
@@ -42,7 +44,8 @@ pub struct AcmeTlsHandler<M: HasModel, S> {
|
|||||||
pub acme_cache: AcmeTlsAlpnCache,
|
pub acme_cache: AcmeTlsAlpnCache,
|
||||||
pub crypto_provider: Arc<CryptoProvider>,
|
pub crypto_provider: Arc<CryptoProvider>,
|
||||||
pub get_provider: S,
|
pub get_provider: S,
|
||||||
pub in_progress: Watch<BTreeSet<BTreeSet<InternedString>>>,
|
pub in_progress:
|
||||||
|
Watch<BTreeMap<BTreeSet<InternedString>, Shared<BoxFuture<'static, Option<CertifiedKey>>>>>,
|
||||||
}
|
}
|
||||||
impl<M, S> AcmeTlsHandler<M, S>
|
impl<M, S> AcmeTlsHandler<M, S>
|
||||||
where
|
where
|
||||||
@@ -50,8 +53,9 @@ where
|
|||||||
+ DbAccessMut<AcmeCertStore>
|
+ DbAccessMut<AcmeCertStore>
|
||||||
+ HasModel<Model = Model<M>>
|
+ HasModel<Model = Model<M>>
|
||||||
+ Send
|
+ Send
|
||||||
+ Sync,
|
+ Sync
|
||||||
S: GetAcmeProvider + Clone,
|
+ 'static,
|
||||||
|
S: GetAcmeProvider + Clone + Send + Sync,
|
||||||
{
|
{
|
||||||
pub async fn get_cert(&self, san_info: &BTreeSet<InternedString>) -> Option<CertifiedKey> {
|
pub async fn get_cert(&self, san_info: &BTreeSet<InternedString>) -> Option<CertifiedKey> {
|
||||||
let provider = self.get_provider.get_provider(san_info).await?;
|
let provider = self.get_provider.get_provider(san_info).await?;
|
||||||
@@ -71,6 +75,8 @@ where
|
|||||||
.and_then(|c| should_use_cert(&c.0).log_err())
|
.and_then(|c| should_use_cert(&c.0).log_err())
|
||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
{
|
{
|
||||||
|
self.in_progress
|
||||||
|
.send_if_modified(|map| map.remove(san_info).is_some());
|
||||||
return Some(
|
return Some(
|
||||||
CertifiedKey::from_der(
|
CertifiedKey::from_der(
|
||||||
cert.fullchain
|
cert.fullchain
|
||||||
@@ -88,21 +94,6 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !self.in_progress.send_if_modified(|x| {
|
|
||||||
if !x.contains(san_info) {
|
|
||||||
x.insert(san_info.clone());
|
|
||||||
true
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}) {
|
|
||||||
self.in_progress
|
|
||||||
.clone()
|
|
||||||
.wait_for(|x| !x.contains(san_info))
|
|
||||||
.await;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let contact = <M as DbAccessByKey<AcmeSettings>>::access_by_key(&peek, &provider)?
|
let contact = <M as DbAccessByKey<AcmeSettings>>::access_by_key(&peek, &provider)?
|
||||||
.as_contact()
|
.as_contact()
|
||||||
.de()
|
.de()
|
||||||
@@ -121,31 +112,53 @@ where
|
|||||||
.cloned()
|
.cloned()
|
||||||
.map(|d| (d, Watch::new(None)))
|
.map(|d| (d, Watch::new(None)))
|
||||||
.collect::<BTreeMap<_, _>>();
|
.collect::<BTreeMap<_, _>>();
|
||||||
self.acme_cache.mutate(|c| {
|
|
||||||
c.extend(cache_entries.iter().map(|(k, v)| (k.clone(), v.clone())));
|
|
||||||
});
|
|
||||||
|
|
||||||
let cert = async_acme::rustls_helper::order(
|
let cert = self
|
||||||
|identifier, cert| {
|
.in_progress
|
||||||
let domain = InternedString::from_display(&identifier);
|
.send_modify(|map| {
|
||||||
if let Some(entry) = cache_entries.get(&domain) {
|
if let Some(fut) = map.get(san_info).cloned() {
|
||||||
entry.send(Some(Arc::new(cert)));
|
if fut.peek().map_or(true, |f| f.is_some()) {
|
||||||
|
return fut;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
let provider = provider.clone();
|
||||||
},
|
let acme_cache = self.acme_cache.clone();
|
||||||
provider.0.as_str(),
|
let db = self.db.clone();
|
||||||
&identifiers,
|
let fut = async move {
|
||||||
Some(&AcmeCertCache(&self.db)),
|
acme_cache.mutate(|c| {
|
||||||
&contact,
|
c.extend(cache_entries.iter().map(|(k, v)| (k.clone(), v.clone())));
|
||||||
)
|
});
|
||||||
.await
|
|
||||||
.log_err()?;
|
|
||||||
|
|
||||||
self.acme_cache
|
let cert = tokio::time::timeout(
|
||||||
.mutate(|c| c.retain(|c, _| !cache_entries.contains_key(c)));
|
Duration::from_secs(120),
|
||||||
|
async_acme::rustls_helper::order(
|
||||||
|
|identifier, cert| {
|
||||||
|
let domain = InternedString::from_display(&identifier);
|
||||||
|
if let Some(entry) = cache_entries.get(&domain) {
|
||||||
|
entry.send(Some(Arc::new(cert)));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
provider.0.as_str(),
|
||||||
|
&identifiers,
|
||||||
|
Some(&AcmeCertCache(&db)),
|
||||||
|
&contact,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.log_err()?
|
||||||
|
.log_err()?;
|
||||||
|
|
||||||
self.in_progress.send_modify(|i| i.remove(san_info));
|
acme_cache.mutate(|c| c.retain(|c, _| !cache_entries.contains_key(c)));
|
||||||
|
|
||||||
|
Some(cert)
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
|
.shared();
|
||||||
|
map.insert(san_info.clone(), fut.clone());
|
||||||
|
fut
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
return Some(cert);
|
return Some(cert);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -166,7 +179,8 @@ where
|
|||||||
+ DbAccessMut<AcmeCertStore>
|
+ DbAccessMut<AcmeCertStore>
|
||||||
+ HasModel<Model = Model<M>>
|
+ HasModel<Model = Model<M>>
|
||||||
+ Send
|
+ Send
|
||||||
+ Sync,
|
+ Sync
|
||||||
|
+ 'static,
|
||||||
S: GetAcmeProvider + Clone + Send + Sync,
|
S: GetAcmeProvider + Clone + Send + Sync,
|
||||||
{
|
{
|
||||||
async fn get_config(
|
async fn get_config(
|
||||||
@@ -462,6 +476,7 @@ impl ValueParserFactory for AcmeProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize, Parser, TS)]
|
#[derive(Deserialize, Serialize, Parser, TS)]
|
||||||
|
#[group(skip)]
|
||||||
#[ts(export)]
|
#[ts(export)]
|
||||||
pub struct InitAcmeParams {
|
pub struct InitAcmeParams {
|
||||||
#[arg(long, help = "help.arg.acme-provider")]
|
#[arg(long, help = "help.arg.acme-provider")]
|
||||||
@@ -488,6 +503,7 @@ pub async fn init(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize, Parser, TS)]
|
#[derive(Deserialize, Serialize, Parser, TS)]
|
||||||
|
#[group(skip)]
|
||||||
#[ts(export)]
|
#[ts(export)]
|
||||||
pub struct RemoveAcmeParams {
|
pub struct RemoveAcmeParams {
|
||||||
#[arg(long, help = "help.arg.acme-provider")]
|
#[arg(long, help = "help.arg.acme-provider")]
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ pub struct PassthroughInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize, Parser)]
|
#[derive(Debug, Clone, Deserialize, Serialize, Parser)]
|
||||||
|
#[group(skip)]
|
||||||
#[serde(rename_all = "kebab-case")]
|
#[serde(rename_all = "kebab-case")]
|
||||||
struct AddPassthroughParams {
|
struct AddPassthroughParams {
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
@@ -79,6 +80,7 @@ struct AddPassthroughParams {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize, Parser)]
|
#[derive(Debug, Clone, Deserialize, Serialize, Parser)]
|
||||||
|
#[group(skip)]
|
||||||
#[serde(rename_all = "kebab-case")]
|
#[serde(rename_all = "kebab-case")]
|
||||||
struct RemovePassthroughParams {
|
struct RemovePassthroughParams {
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
@@ -959,7 +961,8 @@ where
|
|||||||
+ DbAccessMut<AcmeCertStore>
|
+ DbAccessMut<AcmeCertStore>
|
||||||
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
|
+ DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
|
||||||
+ Send
|
+ Send
|
||||||
+ Sync,
|
+ Sync
|
||||||
|
+ 'static,
|
||||||
A: Accept + 'static,
|
A: Accept + 'static,
|
||||||
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
|
<A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
|
||||||
+ Visit<ExtractVisitor<GatewayInfo>>
|
+ Visit<ExtractVisitor<GatewayInfo>>
|
||||||
@@ -1088,7 +1091,7 @@ impl<A: Accept> VHostServer<A> {
|
|||||||
acme_cache,
|
acme_cache,
|
||||||
crypto_provider: crypto_provider.clone(),
|
crypto_provider: crypto_provider.clone(),
|
||||||
get_provider: GetVHostAcmeProvider(mapping.clone()),
|
get_provider: GetVHostAcmeProvider(mapping.clone()),
|
||||||
in_progress: Watch::new(BTreeSet::new()),
|
in_progress: Watch::new(BTreeMap::new()),
|
||||||
}),
|
}),
|
||||||
RootCaTlsHandler {
|
RootCaTlsHandler {
|
||||||
db,
|
db,
|
||||||
|
|||||||
Reference in New Issue
Block a user