diff --git a/core/src/net/acme.rs b/core/src/net/acme.rs index 68a352ae7..ccf44a998 100644 --- a/core/src/net/acme.rs +++ b/core/src/net/acme.rs @@ -2,11 +2,13 @@ use std::collections::{BTreeMap, BTreeSet}; use std::net::IpAddr; use std::str::FromStr; use std::sync::Arc; +use std::time::Duration; use async_acme::acme::{ACME_TLS_ALPN_NAME, Identifier}; use clap::Parser; use clap::builder::ValueParserFactory; -use futures::StreamExt; +use futures::future::{BoxFuture, Shared}; +use futures::{FutureExt, StreamExt}; use imbl_value::InternedString; use itertools::Itertools; use openssl::pkey::{PKey, Private}; @@ -42,7 +44,8 @@ pub struct AcmeTlsHandler { pub acme_cache: AcmeTlsAlpnCache, pub crypto_provider: Arc, pub get_provider: S, - pub in_progress: Watch>>, + pub in_progress: + Watch, Shared>>>>, } impl AcmeTlsHandler where @@ -50,8 +53,9 @@ where + DbAccessMut + HasModel> + Send - + Sync, - S: GetAcmeProvider + Clone, + + Sync + + 'static, + S: GetAcmeProvider + Clone + Send + Sync, { pub async fn get_cert(&self, san_info: &BTreeSet) -> Option { let provider = self.get_provider.get_provider(san_info).await?; @@ -71,6 +75,8 @@ where .and_then(|c| should_use_cert(&c.0).log_err()) .unwrap_or(false) { + self.in_progress + .send_if_modified(|map| map.remove(san_info).is_some()); return Some( CertifiedKey::from_der( cert.fullchain @@ -88,21 +94,6 @@ where } } - if !self.in_progress.send_if_modified(|x| { - if !x.contains(san_info) { - x.insert(san_info.clone()); - true - } else { - false - } - }) { - self.in_progress - .clone() - .wait_for(|x| !x.contains(san_info)) - .await; - continue; - } - let contact = >::access_by_key(&peek, &provider)? .as_contact() .de() @@ -121,31 +112,53 @@ where .cloned() .map(|d| (d, Watch::new(None))) .collect::>(); - self.acme_cache.mutate(|c| { - c.extend(cache_entries.iter().map(|(k, v)| (k.clone(), v.clone()))); - }); - let cert = async_acme::rustls_helper::order( - |identifier, cert| { - let domain = InternedString::from_display(&identifier); - if let Some(entry) = cache_entries.get(&domain) { - entry.send(Some(Arc::new(cert))); + let cert = self + .in_progress + .send_modify(|map| { + if let Some(fut) = map.get(san_info).cloned() { + if fut.peek().map_or(true, |f| f.is_some()) { + return fut; + } } - Ok(()) - }, - provider.0.as_str(), - &identifiers, - Some(&AcmeCertCache(&self.db)), - &contact, - ) - .await - .log_err()?; + let provider = provider.clone(); + let acme_cache = self.acme_cache.clone(); + let db = self.db.clone(); + let fut = async move { + acme_cache.mutate(|c| { + c.extend(cache_entries.iter().map(|(k, v)| (k.clone(), v.clone()))); + }); - self.acme_cache - .mutate(|c| c.retain(|c, _| !cache_entries.contains_key(c))); + let cert = tokio::time::timeout( + Duration::from_secs(120), + async_acme::rustls_helper::order( + |identifier, cert| { + let domain = InternedString::from_display(&identifier); + if let Some(entry) = cache_entries.get(&domain) { + entry.send(Some(Arc::new(cert))); + } + Ok(()) + }, + provider.0.as_str(), + &identifiers, + Some(&AcmeCertCache(&db)), + &contact, + ), + ) + .await + .log_err()? + .log_err()?; - self.in_progress.send_modify(|i| i.remove(san_info)); + acme_cache.mutate(|c| c.retain(|c, _| !cache_entries.contains_key(c))); + Some(cert) + } + .boxed() + .shared(); + map.insert(san_info.clone(), fut.clone()); + fut + }) + .await?; return Some(cert); } } @@ -166,7 +179,8 @@ where + DbAccessMut + HasModel> + Send - + Sync, + + Sync + + 'static, S: GetAcmeProvider + Clone + Send + Sync, { async fn get_config( @@ -462,6 +476,7 @@ impl ValueParserFactory for AcmeProvider { } #[derive(Deserialize, Serialize, Parser, TS)] +#[group(skip)] #[ts(export)] pub struct InitAcmeParams { #[arg(long, help = "help.arg.acme-provider")] @@ -488,6 +503,7 @@ pub async fn init( } #[derive(Deserialize, Serialize, Parser, TS)] +#[group(skip)] #[ts(export)] pub struct RemoveAcmeParams { #[arg(long, help = "help.arg.acme-provider")] diff --git a/core/src/net/vhost.rs b/core/src/net/vhost.rs index 6b4962e50..3e5fae4db 100644 --- a/core/src/net/vhost.rs +++ b/core/src/net/vhost.rs @@ -64,6 +64,7 @@ pub struct PassthroughInfo { } #[derive(Debug, Clone, Deserialize, Serialize, Parser)] +#[group(skip)] #[serde(rename_all = "kebab-case")] struct AddPassthroughParams { #[arg(long)] @@ -79,6 +80,7 @@ struct AddPassthroughParams { } #[derive(Debug, Clone, Deserialize, Serialize, Parser)] +#[group(skip)] #[serde(rename_all = "kebab-case")] struct RemovePassthroughParams { #[arg(long)] @@ -959,7 +961,8 @@ where + DbAccessMut + DbAccessByKey = &'a AcmeProvider> + Send - + Sync, + + Sync + + 'static, A: Accept + 'static, ::Metadata: Visit> + Visit> @@ -1088,7 +1091,7 @@ impl VHostServer { acme_cache, crypto_provider: crypto_provider.clone(), get_provider: GetVHostAcmeProvider(mapping.clone()), - in_progress: Watch::new(BTreeSet::new()), + in_progress: Watch::new(BTreeMap::new()), }), RootCaTlsHandler { db,