diff --git a/container-runtime/src/Adapters/EffectCreator.ts b/container-runtime/src/Adapters/EffectCreator.ts index bcc040f4f..daba0304e 100644 --- a/container-runtime/src/Adapters/EffectCreator.ts +++ b/container-runtime/src/Adapters/EffectCreator.ts @@ -253,6 +253,14 @@ export function makeEffects(context: EffectContext): Effects { callback: context.callbacks?.addCallback(options.callback) || null, }) as ReturnType }, + getOutboundGateway( + ...[options]: Parameters + ) { + return rpcRound("get-outbound-gateway", { + ...options, + callback: context.callbacks?.addCallback(options.callback) || null, + }) as ReturnType + }, listServiceInterfaces( ...[options]: Parameters ) { diff --git a/core/src/service/effects/callbacks.rs b/core/src/service/effects/callbacks.rs index 58d6cd8ff..d30665c96 100644 --- a/core/src/service/effects/callbacks.rs +++ b/core/src/service/effects/callbacks.rs @@ -5,15 +5,16 @@ use std::time::{Duration, SystemTime}; use clap::Parser; use futures::future::join_all; -use imbl::{Vector, vector}; +use imbl::{OrdMap, Vector, vector}; use imbl_value::InternedString; +use patch_db::TypedDbWatch; +use patch_db::json_ptr::JsonPointer; use serde::{Deserialize, Serialize}; use tracing::warn; use ts_rs::TS; -use patch_db::json_ptr::JsonPointer; - use crate::db::model::Database; +use crate::db::model::public::NetworkInterfaceInfo; use crate::net::ssl::FullchainCertData; use crate::prelude::*; use crate::service::effects::context::EffectContext; @@ -22,7 +23,7 @@ use crate::service::rpc::{CallbackHandle, CallbackId}; use crate::service::{Service, ServiceActorSeed}; use crate::util::collections::EqMap; use crate::util::future::NonDetachingJoinHandle; -use crate::{HostId, PackageId, ServiceInterfaceId}; +use crate::{GatewayId, HostId, PackageId, ServiceInterfaceId}; #[derive(Default)] pub struct ServiceCallbacks(Mutex); @@ -32,7 +33,8 @@ struct ServiceCallbackMap { get_service_interface: BTreeMap<(PackageId, ServiceInterfaceId), Vec>, list_service_interfaces: BTreeMap>, get_system_smtp: Vec, - get_host_info: BTreeMap<(PackageId, HostId), (NonDetachingJoinHandle<()>, Vec)>, + get_host_info: + BTreeMap<(PackageId, HostId), (NonDetachingJoinHandle<()>, Vec)>, get_ssl_certificate: EqMap< (BTreeSet, FullchainCertData, Algorithm), (NonDetachingJoinHandle<()>, Vec), @@ -40,6 +42,7 @@ struct ServiceCallbackMap { get_status: BTreeMap>, get_container_ip: BTreeMap>, get_service_manifest: BTreeMap>, + get_outbound_gateway: BTreeMap, Vec)>, } impl ServiceCallbacks { @@ -76,6 +79,10 @@ impl ServiceCallbacks { v.retain(|h| h.handle.is_active() && h.seed.strong_count() > 0); !v.is_empty() }); + this.get_outbound_gateway.retain(|_, (_, v)| { + v.retain(|h| h.handle.is_active() && h.seed.strong_count() > 0); + !v.is_empty() + }); }) } @@ -154,12 +161,10 @@ impl ServiceCallbacks { this.get_host_info .entry((package_id.clone(), host_id.clone())) .or_insert_with(|| { - let ptr: JsonPointer = format!( - "/public/packageData/{}/hosts/{}", - package_id, host_id - ) - .parse() - .expect("valid json pointer"); + let ptr: JsonPointer = + format!("/public/packageData/{}/hosts/{}", package_id, host_id) + .parse() + .expect("valid json pointer"); let db = db.clone(); let callbacks = Arc::clone(self); let key = (package_id, host_id); @@ -174,9 +179,7 @@ impl ServiceCallbacks { .filter(|cb| !cb.0.is_empty()) }) { if let Err(e) = cbs.call(vector![]).await { - tracing::error!( - "Error in host info callback: {e}" - ); + tracing::error!("Error in host info callback: {e}"); tracing::debug!("{e:?}"); } } @@ -287,6 +290,61 @@ impl ServiceCallbacks { }) } + /// Register a callback for outbound gateway changes. + pub(super) fn add_get_outbound_gateway( + self: &Arc, + package_id: PackageId, + mut outbound_gateway: TypedDbWatch>, + mut default_outbound: Option>>, + mut fallback: Option>>, + handler: CallbackHandler, + ) { + self.mutate(|this| { + this.get_outbound_gateway + .entry(package_id.clone()) + .or_insert_with(|| { + let callbacks = Arc::clone(self); + let key = package_id; + ( + tokio::spawn(async move { + tokio::select! { + _ = outbound_gateway.changed() => {} + _ = async { + if let Some(ref mut w) = default_outbound { + let _ = w.changed().await; + } else { + std::future::pending::<()>().await; + } + } => {} + _ = async { + if let Some(ref mut w) = fallback { + let _ = w.changed().await; + } else { + std::future::pending::<()>().await; + } + } => {} + } + if let Some(cbs) = callbacks.mutate(|this| { + this.get_outbound_gateway + .remove(&key) + .map(|(_, handlers)| CallbackHandlers(handlers)) + .filter(|cb| !cb.0.is_empty()) + }) { + if let Err(e) = cbs.call(vector![]).await { + tracing::error!("Error in outbound gateway callback: {e}"); + tracing::debug!("{e:?}"); + } + } + }) + .into(), + Vec::new(), + ) + }) + .1 + .push(handler); + }) + } + pub(super) fn add_get_service_manifest(&self, package_id: PackageId, handler: CallbackHandler) { self.mutate(|this| { this.get_service_manifest diff --git a/core/src/service/effects/mod.rs b/core/src/service/effects/mod.rs index 07faaf4af..e3116da13 100644 --- a/core/src/service/effects/mod.rs +++ b/core/src/service/effects/mod.rs @@ -143,6 +143,10 @@ pub fn handler() -> ParentHandler { "get-container-ip", from_fn_async(net::info::get_container_ip).no_cli(), ) + .subcommand( + "get-outbound-gateway", + from_fn_async(net::info::get_outbound_gateway).no_cli(), + ) .subcommand( "get-os-ip", from_fn(|_: C| Ok::<_, Error>(Ipv4Addr::from(HOST_IP))), diff --git a/core/src/service/effects/net/info.rs b/core/src/service/effects/net/info.rs index f14ee72dc..ef8507e47 100644 --- a/core/src/service/effects/net/info.rs +++ b/core/src/service/effects/net/info.rs @@ -1,9 +1,16 @@ use std::net::Ipv4Addr; -use crate::PackageId; +use imbl::OrdMap; +use patch_db::TypedDbWatch; +use patch_db::json_ptr::JsonPointer; +use tokio::process::Command; + +use crate::db::model::public::NetworkInterfaceInfo; use crate::service::effects::callbacks::CallbackHandler; use crate::service::effects::prelude::*; use crate::service::rpc::CallbackId; +use crate::util::Invoke; +use crate::{GatewayId, PackageId}; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, TS)] #[serde(rename_all = "camelCase")] @@ -51,3 +58,116 @@ pub async fn get_container_ip( lxc.ip().await.map(Some) } } + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, TS)] +#[serde(rename_all = "camelCase")] +#[ts(export)] +pub struct GetOutboundGatewayParams { + #[ts(optional)] + callback: Option, +} + +pub async fn get_outbound_gateway( + context: EffectContext, + GetOutboundGatewayParams { callback }: GetOutboundGatewayParams, +) -> Result { + let context = context.deref()?; + let ctx = &context.seed.ctx; + + // Resolve the effective gateway; DB watches are created atomically + // with each read to avoid race conditions. + let (gw, pkg_watch, os_watch, gateways_watch) = + resolve_outbound_gateway(ctx, &context.seed.id).await?; + + if let Some(callback) = callback { + let callback = callback.register(&context.seed.persistent_container); + context.seed.ctx.callbacks.add_get_outbound_gateway( + context.seed.id.clone(), + pkg_watch, + os_watch, + gateways_watch, + CallbackHandler::new(&context, callback), + ); + } + + Ok(gw) +} + +async fn resolve_outbound_gateway( + ctx: &crate::context::RpcContext, + package_id: &PackageId, +) -> Result< + ( + GatewayId, + TypedDbWatch>, + Option>>, + Option>>, + ), + Error, +> { + // 1. Package-specific outbound gateway — subscribe before reading + let pkg_ptr: JsonPointer = format!("/public/packageData/{}/outboundGateway", package_id) + .parse() + .expect("valid json pointer"); + let mut pkg_watch = ctx.db.watch(pkg_ptr).await; + let pkg_gw: Option = imbl_value::from_value(pkg_watch.peek_and_mark_seen()?)?; + + if let Some(gw) = pkg_gw { + return Ok((gw, pkg_watch.typed(), None, None)); + } + + // 2. OS-level default outbound — subscribe before reading + let os_ptr: JsonPointer = "/public/serverInfo/network/defaultOutbound" + .parse() + .expect("valid json pointer"); + let mut os_watch = ctx.db.watch(os_ptr).await; + let default_outbound: Option = + imbl_value::from_value(os_watch.peek_and_mark_seen()?)?; + + if let Some(gw) = default_outbound { + return Ok((gw, pkg_watch.typed(), Some(os_watch.typed()), None)); + } + + // 3. Fall through to main routing table — watch gateways for changes + let gw_ptr: JsonPointer = "/public/serverInfo/network/gateways" + .parse() + .expect("valid json pointer"); + let mut gateways_watch = ctx.db.watch(gw_ptr).await; + gateways_watch.peek_and_mark_seen()?; + + let gw = default_route_interface().await?; + Ok(( + gw, + pkg_watch.typed(), + Some(os_watch.typed()), + Some(gateways_watch.typed()), + )) +} + +/// Parses `ip route show table main` for the default route's `dev` field. +async fn default_route_interface() -> Result { + let output = Command::new("ip") + .arg("route") + .arg("show") + .arg("table") + .arg("main") + .invoke(ErrorKind::Network) + .await?; + let text = String::from_utf8_lossy(&output); + for line in text.lines() { + if line.starts_with("default ") { + let mut parts = line.split_whitespace(); + while let Some(tok) = parts.next() { + if tok == "dev" { + if let Some(dev) = parts.next() { + return Ok(dev.parse().unwrap()); + } + } + } + } + } + Err(Error::new( + eyre!("no default route found in main routing table"), + ErrorKind::Network, + )) +} diff --git a/sdk/base/lib/Effects.ts b/sdk/base/lib/Effects.ts index 596fffe92..d3d0b8923 100644 --- a/sdk/base/lib/Effects.ts +++ b/sdk/base/lib/Effects.ts @@ -135,6 +135,8 @@ export type Effects = { }): Promise /** Returns the IP address of StartOS */ getOsIp(): Promise + /** Returns the effective outbound gateway for this service */ + getOutboundGateway(options: { callback?: () => void }): Promise // interface /** Creates an interface bound to a specific host and port to show to the user */ exportServiceInterface(options: ExportServiceInterfaceParams): Promise diff --git a/sdk/base/lib/osBindings/GetOutboundGatewayParams.ts b/sdk/base/lib/osBindings/GetOutboundGatewayParams.ts new file mode 100644 index 000000000..703fb4f08 --- /dev/null +++ b/sdk/base/lib/osBindings/GetOutboundGatewayParams.ts @@ -0,0 +1,4 @@ +// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually. +import type { CallbackId } from './CallbackId' + +export type GetOutboundGatewayParams = { callback?: CallbackId } diff --git a/sdk/base/lib/osBindings/index.ts b/sdk/base/lib/osBindings/index.ts index 57fbb9352..1b683f950 100644 --- a/sdk/base/lib/osBindings/index.ts +++ b/sdk/base/lib/osBindings/index.ts @@ -110,6 +110,7 @@ export { GetContainerIpParams } from './GetContainerIpParams' export { GetHostInfoParams } from './GetHostInfoParams' export { GetOsAssetParams } from './GetOsAssetParams' export { GetOsVersionParams } from './GetOsVersionParams' +export { GetOutboundGatewayParams } from './GetOutboundGatewayParams' export { GetPackageParams } from './GetPackageParams' export { GetPackageResponseFull } from './GetPackageResponseFull' export { GetPackageResponse } from './GetPackageResponse' diff --git a/sdk/base/lib/test/startosTypeValidation.test.ts b/sdk/base/lib/test/startosTypeValidation.test.ts index 52da23a7e..f02336d7e 100644 --- a/sdk/base/lib/test/startosTypeValidation.test.ts +++ b/sdk/base/lib/test/startosTypeValidation.test.ts @@ -25,6 +25,7 @@ import { GetSslKeyParams } from '.././osBindings' import { GetServiceInterfaceParams } from '.././osBindings' import { SetDependenciesParams } from '.././osBindings' import { GetSystemSmtpParams } from '.././osBindings' +import { GetOutboundGatewayParams } from '.././osBindings' import { GetServicePortForwardParams } from '.././osBindings' import { ExportServiceInterfaceParams } from '.././osBindings' import { ListServiceInterfacesParams } from '.././osBindings' @@ -83,6 +84,7 @@ describe('startosTypeValidation ', () => { getServiceManifest: {} as WithCallback, getSystemSmtp: {} as WithCallback, getContainerIp: {} as WithCallback, + getOutboundGateway: {} as WithCallback, getOsIp: undefined, getServicePortForward: {} as GetServicePortForwardParams, clearServiceInterfaces: {} as ClearServiceInterfacesParams, diff --git a/sdk/base/lib/util/GetOutboundGateway.ts b/sdk/base/lib/util/GetOutboundGateway.ts new file mode 100644 index 000000000..16a8f95da --- /dev/null +++ b/sdk/base/lib/util/GetOutboundGateway.ts @@ -0,0 +1,105 @@ +import { Effects } from '../Effects' +import { DropGenerator, DropPromise } from './Drop' + +export class GetOutboundGateway { + constructor(readonly effects: Effects) {} + + /** + * Returns the effective outbound gateway. Reruns the context from which it has been called if the underlying value changes + */ + const() { + return this.effects.getOutboundGateway({ + callback: + this.effects.constRetry && + (() => this.effects.constRetry && this.effects.constRetry()), + }) + } + /** + * Returns the effective outbound gateway. Does nothing if the value changes + */ + once() { + return this.effects.getOutboundGateway({}) + } + + private async *watchGen(abort?: AbortSignal) { + const resolveCell = { resolve: () => {} } + this.effects.onLeaveContext(() => { + resolveCell.resolve() + }) + abort?.addEventListener('abort', () => resolveCell.resolve()) + while (this.effects.isInContext && !abort?.aborted) { + let callback: () => void = () => {} + const waitForNext = new Promise((resolve) => { + callback = resolve + resolveCell.resolve = resolve + }) + yield await this.effects.getOutboundGateway({ + callback: () => callback(), + }) + await waitForNext + } + return new Promise((_, rej) => rej(new Error('aborted'))) + } + + /** + * Watches the effective outbound gateway. Returns an async iterator that yields whenever the value changes + */ + watch(abort?: AbortSignal): AsyncGenerator { + const ctrl = new AbortController() + abort?.addEventListener('abort', () => ctrl.abort()) + return DropGenerator.of(this.watchGen(ctrl.signal), () => ctrl.abort()) + } + + /** + * Watches the effective outbound gateway. Takes a custom callback function to run whenever the value changes + */ + onChange( + callback: ( + value: string, + error?: Error, + ) => { cancel: boolean } | Promise<{ cancel: boolean }>, + ) { + ;(async () => { + const ctrl = new AbortController() + for await (const value of this.watch(ctrl.signal)) { + try { + const res = await callback(value) + if (res.cancel) { + ctrl.abort() + break + } + } catch (e) { + console.error( + 'callback function threw an error @ GetOutboundGateway.onChange', + e, + ) + } + } + })() + .catch((e) => callback('', e)) + .catch((e) => + console.error( + 'callback function threw an error @ GetOutboundGateway.onChange', + e, + ), + ) + } + + /** + * Watches the effective outbound gateway. Returns when the predicate is true + */ + waitFor(pred: (value: string) => boolean): Promise { + const ctrl = new AbortController() + return DropPromise.of( + Promise.resolve().then(async () => { + for await (const next of this.watchGen(ctrl.signal)) { + if (pred(next)) { + return next + } + } + return '' + }), + () => ctrl.abort(), + ) + } +} diff --git a/sdk/base/lib/util/index.ts b/sdk/base/lib/util/index.ts index 303cf5f73..f26338438 100644 --- a/sdk/base/lib/util/index.ts +++ b/sdk/base/lib/util/index.ts @@ -14,6 +14,7 @@ export { once } from './once' export { asError } from './asError' export * as Patterns from './patterns' export * from './typeHelpers' +export { GetOutboundGateway } from './GetOutboundGateway' export { GetSystemSmtp } from './GetSystemSmtp' export { Graph, Vertex } from './graph' export { inMs } from './inMs' diff --git a/sdk/package/lib/StartSdk.ts b/sdk/package/lib/StartSdk.ts index d8d255af2..9eb0036d1 100644 --- a/sdk/package/lib/StartSdk.ts +++ b/sdk/package/lib/StartSdk.ts @@ -23,7 +23,7 @@ import { setupExportedUrls } from '../../base/lib/interfaces/setupExportedUrls' import { successFailure } from './trigger/successFailure' import { MultiHost, Scheme } from '../../base/lib/interfaces/Host' import { ServiceInterfaceBuilder } from '../../base/lib/interfaces/ServiceInterfaceBuilder' -import { GetSystemSmtp } from './util' +import { GetOutboundGateway, GetSystemSmtp } from './util' import { nullIfEmpty } from './util' import { getServiceInterface, getServiceInterfaces } from './util' import { @@ -107,6 +107,7 @@ export class StartSdk { type AlreadyExposed = | 'getSslCertificate' | 'getSystemSmtp' + | 'getOutboundGateway' | 'getContainerIp' | 'getDataVersion' | 'setDataVersion' @@ -445,6 +446,8 @@ export class StartSdk { ) => new ServiceInterfaceBuilder({ ...options, effects }), getSystemSmtp: (effects: E) => new GetSystemSmtp(effects), + getOutboundGateway: (effects: E) => + new GetOutboundGateway(effects), getSslCertificate: ( effects: E, hostnames: string[], diff --git a/sdk/package/lib/version/VersionGraph.ts b/sdk/package/lib/version/VersionGraph.ts index 396497de5..08109e023 100644 --- a/sdk/package/lib/version/VersionGraph.ts +++ b/sdk/package/lib/version/VersionGraph.ts @@ -64,8 +64,6 @@ export class VersionGraph private constructor( readonly current: VersionInfo, versions: Array>, - private readonly preInstall?: InitScriptOrFn<'install'>, - private readonly uninstall?: UninitScript | UninitFn, ) { this.graph = once(() => { const graph = new Graph< @@ -167,24 +165,8 @@ export class VersionGraph static of< CurrentVersion extends string, OtherVersions extends Array>, - >(options: { - current: VersionInfo - other: OtherVersions - /** - * A script to run only on fresh install - */ - preInstall?: InitScriptOrFn<'install'> - /** - * A script to run only on uninstall - */ - uninstall?: UninitScriptOrFn - }) { - return new VersionGraph( - options.current, - options.other, - options.preInstall, - options.uninstall, - ) + >(options: { current: VersionInfo; other: OtherVersions }) { + return new VersionGraph(options.current, options.other) } async migrate({ effects, @@ -270,7 +252,7 @@ export class VersionGraph .normalize(), ) - async init(effects: T.Effects, kind: InitKind): Promise { + async init(effects: T.Effects): Promise { const from = await getDataVersion(effects) if (from) { await this.migrate({ @@ -279,10 +261,6 @@ export class VersionGraph to: this.currentVersion(), }) } else { - kind = 'install' // implied by !dataVersion - if (this.preInstall) - if ('init' in this.preInstall) await this.preInstall.init(effects, kind) - else await this.preInstall(effects, kind) await effects.setDataVersion({ version: this.current.options.version }) } } @@ -300,11 +278,6 @@ export class VersionGraph to: target, }) } - } else { - if (this.uninstall) - if ('uninit' in this.uninstall) - await this.uninstall.uninit(effects, target) - else await this.uninstall(effects, target) } await setDataVersion(effects, target) }