Files
start-os/core/src/service/effects/callbacks.rs
Aiden McClelland 49d4da03ca feat: refactor NetService to watch DB and reconcile network state
- NetService sync task now uses PatchDB DbWatch instead of being called
  directly after DB mutations
- Read gateways from DB instead of network interface context when
  updating host addresses
- gateway sync updates all host addresses in the DB
- Add Watch<u64> channel for callers to wait on sync completion
- Fix ts-rs codegen bug with #[ts(skip)] on flattened Plugin field
- Update SDK getServiceInterface.ts for new HostnameInfo shape
- Remove unnecessary HTTPS redirect in static_server.rs
- Fix tunnel/api.rs to filter for WAN IPv4 address
2026-02-13 16:21:57 -07:00

375 lines
13 KiB
Rust

use std::cmp::min;
use std::collections::{BTreeMap, BTreeSet};
use std::sync::{Arc, Mutex, Weak};
use std::time::{Duration, SystemTime};
use clap::Parser;
use futures::future::join_all;
use imbl::{Vector, vector};
use imbl_value::InternedString;
use serde::{Deserialize, Serialize};
use tracing::warn;
use ts_rs::TS;
use patch_db::json_ptr::JsonPointer;
use crate::db::model::Database;
use crate::net::ssl::FullchainCertData;
use crate::prelude::*;
use crate::service::effects::context::EffectContext;
use crate::service::effects::net::ssl::Algorithm;
use crate::service::rpc::{CallbackHandle, CallbackId};
use crate::service::{Service, ServiceActorSeed};
use crate::util::collections::EqMap;
use crate::util::future::NonDetachingJoinHandle;
use crate::{HostId, PackageId, ServiceInterfaceId};
/// Registry of one-shot callbacks that services registered through effects.
///
/// All state lives behind a single `Mutex`. Each consuming getter on this type removes the
/// handlers it returns (`remove` / `mem::take`), so a given handler fires at most once; a
/// caller that wants further notifications has to register again.
#[derive(Default)]
pub struct ServiceCallbacks(Mutex<ServiceCallbackMap>);
/// The per-kind handler lists guarded by [`ServiceCallbacks`].
#[derive(Default)]
struct ServiceCallbackMap {
/// Handlers keyed by (package, service interface id).
get_service_interface: BTreeMap<(PackageId, ServiceInterfaceId), Vec<CallbackHandler>>,
/// Handlers keyed by package.
list_service_interfaces: BTreeMap<PackageId, Vec<CallbackHandler>>,
/// Handlers not tied to any package.
get_system_smtp: Vec<CallbackHandler>,
/// Handlers keyed by (package, host id), stored next to the join handle of the DB-watching
/// task that `add_get_host_info` spawns for that key.
get_host_info: BTreeMap<(PackageId, HostId), (NonDetachingJoinHandle<()>, Vec<CallbackHandler>)>,
/// Handlers keyed by (hostnames, certificate, algorithm), stored next to the join handle of
/// the expiry-timer task that `add_get_ssl_certificate` spawns for that key.
get_ssl_certificate: EqMap<
(BTreeSet<InternedString>, FullchainCertData, Algorithm),
(NonDetachingJoinHandle<()>, Vec<CallbackHandler>),
>,
/// Handlers keyed by package.
get_status: BTreeMap<PackageId, Vec<CallbackHandler>>,
/// Handlers keyed by package.
get_container_ip: BTreeMap<PackageId, Vec<CallbackHandler>>,
/// Handlers keyed by package.
get_service_manifest: BTreeMap<PackageId, Vec<CallbackHandler>>,
}
impl ServiceCallbacks {
fn mutate<T>(&self, f: impl FnOnce(&mut ServiceCallbackMap) -> T) -> T {
let mut this = self.0.lock().unwrap();
f(&mut *this)
}
pub fn gc(&self) {
self.mutate(|this| {
this.get_service_interface.retain(|_, v| {
v.retain(|h| h.handle.is_active() && h.seed.strong_count() > 0);
!v.is_empty()
});
this.list_service_interfaces.retain(|_, v| {
v.retain(|h| h.handle.is_active() && h.seed.strong_count() > 0);
!v.is_empty()
});
this.get_system_smtp
.retain(|h| h.handle.is_active() && h.seed.strong_count() > 0);
this.get_host_info.retain(|_, (_, v)| {
v.retain(|h| h.handle.is_active() && h.seed.strong_count() > 0);
!v.is_empty()
});
this.get_ssl_certificate.retain(|_, (_, v)| {
v.retain(|h| h.handle.is_active() && h.seed.strong_count() > 0);
!v.is_empty()
});
this.get_status.retain(|_, v| {
v.retain(|h| h.handle.is_active() && h.seed.strong_count() > 0);
!v.is_empty()
});
this.get_service_manifest.retain(|_, v| {
v.retain(|h| h.handle.is_active() && h.seed.strong_count() > 0);
!v.is_empty()
});
})
}
pub(super) fn add_get_service_interface(
&self,
package_id: PackageId,
service_interface_id: ServiceInterfaceId,
handler: CallbackHandler,
) {
self.mutate(|this| {
this.get_service_interface
.entry((package_id, service_interface_id))
.or_default()
.push(handler);
})
}
#[must_use]
pub fn get_service_interface(
&self,
id: &(PackageId, ServiceInterfaceId),
) -> Option<CallbackHandlers> {
self.mutate(|this| {
Some(CallbackHandlers(
this.get_service_interface.remove(id).unwrap_or_default(),
))
.filter(|cb| !cb.0.is_empty())
})
}
pub(super) fn add_list_service_interfaces(
&self,
package_id: PackageId,
handler: CallbackHandler,
) {
self.mutate(|this| {
this.list_service_interfaces
.entry(package_id)
.or_default()
.push(handler);
})
}
#[must_use]
pub fn list_service_interfaces(&self, id: &PackageId) -> Option<CallbackHandlers> {
self.mutate(|this| {
Some(CallbackHandlers(
this.list_service_interfaces.remove(id).unwrap_or_default(),
))
.filter(|cb| !cb.0.is_empty())
})
}
pub(super) fn add_get_system_smtp(&self, handler: CallbackHandler) {
self.mutate(|this| {
this.get_system_smtp.push(handler);
})
}
#[must_use]
pub fn get_system_smtp(&self) -> Option<CallbackHandlers> {
self.mutate(|this| {
Some(CallbackHandlers(std::mem::take(&mut this.get_system_smtp)))
.filter(|cb| !cb.0.is_empty())
})
}
pub(super) fn add_get_host_info(
self: &Arc<Self>,
db: &TypedPatchDb<Database>,
package_id: PackageId,
host_id: HostId,
handler: CallbackHandler,
) {
self.mutate(|this| {
this.get_host_info
.entry((package_id.clone(), host_id.clone()))
.or_insert_with(|| {
let ptr: JsonPointer = format!(
"/public/packageData/{}/hosts/{}",
package_id, host_id
)
.parse()
.expect("valid json pointer");
let db = db.clone();
let callbacks = Arc::clone(self);
let key = (package_id, host_id);
(
tokio::spawn(async move {
let mut sub = db.subscribe(ptr).await;
while sub.recv().await.is_some() {
if let Some(cbs) = callbacks.mutate(|this| {
this.get_host_info
.remove(&key)
.map(|(_, handlers)| CallbackHandlers(handlers))
.filter(|cb| !cb.0.is_empty())
}) {
if let Err(e) = cbs.call(vector![]).await {
tracing::error!(
"Error in host info callback: {e}"
);
tracing::debug!("{e:?}");
}
}
// entry was removed when we consumed handlers,
// so stop watching — a new subscription will be
// created if the service re-registers
break;
}
})
.into(),
Vec::new(),
)
})
.1
.push(handler);
})
}
pub(super) fn add_get_ssl_certificate(
&self,
ctx: EffectContext,
hostnames: BTreeSet<InternedString>,
cert: FullchainCertData,
algorithm: Algorithm,
handler: CallbackHandler,
) {
self.mutate(|this| {
this.get_ssl_certificate
.entry((hostnames.clone(), cert.clone(), algorithm))
.or_insert_with(|| {
(
tokio::spawn(async move {
if let Err(e) = async {
loop {
match cert
.expiration()
.ok()
.and_then(|e| e.duration_since(SystemTime::now()).ok())
{
Some(d) => {
tokio::time::sleep(min(Duration::from_secs(86400), d))
.await
}
_ => break,
}
}
let Ok(ctx) = ctx.deref() else {
return Ok(());
};
if let Some((_, callbacks)) =
ctx.seed.ctx.callbacks.mutate(|this| {
this.get_ssl_certificate
.remove(&(hostnames, cert, algorithm))
})
{
CallbackHandlers(callbacks).call(vector![]).await?;
}
Ok::<_, Error>(())
}
.await
{
tracing::error!(
"Error in callback handler for getSslCertificate: {e}"
);
tracing::debug!("{e:?}");
}
})
.into(),
Vec::new(),
)
})
.1
.push(handler);
})
}
pub(super) fn add_get_status(&self, package_id: PackageId, handler: CallbackHandler) {
self.mutate(|this| this.get_status.entry(package_id).or_default().push(handler))
}
#[must_use]
pub fn get_status(&self, package_id: &PackageId) -> Option<CallbackHandlers> {
self.mutate(|this| {
if let Some(watched) = this.get_status.remove(package_id) {
Some(CallbackHandlers(watched))
} else {
None
}
.filter(|cb| !cb.0.is_empty())
})
}
pub(super) fn add_get_container_ip(&self, package_id: PackageId, handler: CallbackHandler) {
self.mutate(|this| {
this.get_container_ip
.entry(package_id)
.or_default()
.push(handler)
})
}
#[must_use]
pub fn get_container_ip(&self, package_id: &PackageId) -> Option<CallbackHandlers> {
self.mutate(|this| {
this.get_container_ip
.remove(package_id)
.map(CallbackHandlers)
.filter(|cb| !cb.0.is_empty())
})
}
pub(super) fn add_get_service_manifest(&self, package_id: PackageId, handler: CallbackHandler) {
self.mutate(|this| {
this.get_service_manifest
.entry(package_id)
.or_default()
.push(handler)
})
}
#[must_use]
pub fn get_service_manifest(&self, package_id: &PackageId) -> Option<CallbackHandlers> {
self.mutate(|this| {
this.get_service_manifest
.remove(package_id)
.map(CallbackHandlers)
.filter(|cb| !cb.0.is_empty())
})
}
}
/// One registered callback: the container-side handle to invoke, plus a non-owning link
/// to the service actor that registered it.
pub struct CallbackHandler {
    handle: CallbackHandle,
    seed: Weak<ServiceActorSeed>,
}
impl CallbackHandler {
    /// Ties `handle` to `service` without keeping the service alive.
    pub fn new(service: &Service, handle: CallbackHandle) -> Self {
        let seed = Arc::downgrade(&service.seed);
        Self { handle, seed }
    }

    /// Consumes the handler and forwards `args` to the service's persistent container.
    ///
    /// If the owning service has already been dropped, nothing is sent and `Ok(())` is
    /// returned.
    ///
    /// # Errors
    /// Propagates any error from the container callback.
    pub async fn call(mut self, args: Vector<Value>) -> Result<(), Error> {
        let Some(seed) = self.seed.upgrade() else {
            return Ok(());
        };
        let handle = self.handle.take();
        seed.persistent_container.callback(handle, args).await?;
        Ok(())
    }
}
impl Drop for CallbackHandler {
    fn drop(&mut self) {
        // `call` takes the handle out before invoking it; a still-active handle here means
        // this handler was discarded without being fired.
        if !self.handle.is_active() {
            return;
        }
        warn!("Callback handler dropped while still active!");
    }
}
/// A batch of handlers that are fired together with the same arguments.
pub struct CallbackHandlers(Vec<CallbackHandler>);
impl CallbackHandlers {
    /// Fires every handler concurrently, each with its own clone of `args`.
    ///
    /// Every handler runs to completion even if others fail. Failures are accumulated in an
    /// `ErrorCollection`, which produces the final result.
    pub async fn call(self, args: Vector<Value>) -> Result<(), Error> {
        let invocations = self.0.into_iter().map(|handler| handler.call(args.clone()));
        let outcomes = join_all(invocations).await;
        let mut failures = ErrorCollection::new();
        for outcome in outcomes {
            failures.handle(outcome);
        }
        failures.into_result()
    }
}
// Parameters for the `clear_callbacks` effect.
//
// Plain `//` comments are used deliberately: `///` doc comments would likely be picked up by
// the clap (`Parser`) and ts-rs (`TS`) derives and change CLI help / generated bindings.
//
// The exported TS type (overridden below) requires exactly one of `only` / `except`. The two
// flags also conflict at the CLI level. Deserialization itself accepts neither being present;
// `clear_callbacks` then removes nothing.
#[derive(Debug, Clone, Serialize, Deserialize, TS, Parser)]
#[ts(type = "{ only: number[] } | { except: number[] }")]
#[ts(export)]
pub struct ClearCallbacksParams {
// Remove exactly these callback ids.
#[arg(long, conflicts_with = "except", help = "help.arg.only-callbacks")]
pub only: Option<Vec<CallbackId>>,
// Remove every callback id except these.
#[arg(long, conflicts_with = "only", help = "help.arg.except-callbacks")]
pub except: Option<Vec<CallbackId>>,
}
/// Removes callbacks from the calling service's persistent-container state.
///
/// * `only: Some(ids)` removes exactly the listed ids.
/// * `except: Some(ids)` removes every id *not* listed.
/// * With neither set, nothing is removed.
///
/// Afterwards the shared registry's `gc` is run to discard handlers that can no longer fire.
///
/// # Errors
/// Fails if the effect context can no longer be resolved.
pub(super) fn clear_callbacks(
    context: EffectContext,
    ClearCallbacksParams { only, except }: ClearCallbacksParams,
) -> Result<(), Error> {
    let ctx = context.deref()?;
    let remove_listed: Option<BTreeSet<_>> = only.map(|ids| ids.into_iter().collect());
    let keep_listed: Option<BTreeSet<_>> = except.map(|ids| ids.into_iter().collect());
    ctx.seed.persistent_container.state.send_modify(|state| {
        state.callbacks.retain(|id| {
            let named_for_removal = remove_listed.as_ref().is_some_and(|set| set.contains(id));
            let not_exempted = keep_listed.as_ref().is_some_and(|set| !set.contains(id));
            !(named_for_removal || not_exempted)
        })
    });
    ctx.seed.ctx.callbacks.gc();
    Ok(())
}