diff --git a/core/startos/src/tunnel/api.rs b/core/startos/src/tunnel/api.rs index 5b88a5464..0e32f5452 100644 --- a/core/startos/src/tunnel/api.rs +++ b/core/startos/src/tunnel/api.rs @@ -175,15 +175,16 @@ pub async fn remove_subnet( _: Empty, SubnetParams { subnet }: SubnetParams, ) -> Result<(), Error> { - let server = ctx + let (server, keep) = ctx .db .mutate(|db| { db.as_wg_mut().as_subnets_mut().remove(&subnet)?; - db.as_wg().de() + Ok((db.as_wg().de()?, db.gc_forwards()?)) }) .await .result?; - server.sync().await + server.sync().await?; + ctx.gc_forwards(&keep).await } #[derive(Deserialize, Serialize, Parser)] @@ -258,7 +259,7 @@ pub async fn remove_device( ctx: TunnelContext, RemoveDeviceParams { subnet, ip }: RemoveDeviceParams, ) -> Result<(), Error> { - let server = ctx + let (server, keep) = ctx .db .mutate(|db| { db.as_wg_mut() @@ -268,11 +269,12 @@ pub async fn remove_device( .as_clients_mut() .remove(&ip)? .or_not_found(&ip)?; - db.as_wg().de() + Ok((db.as_wg().de()?, db.gc_forwards()?)) }) .await .result?; - server.sync().await + server.sync().await?; + ctx.gc_forwards(&keep).await } #[derive(Deserialize, Serialize, Parser)] @@ -377,7 +379,20 @@ pub async fn add_forward( }); ctx.db - .mutate(|db| db.as_port_forwards_mut().insert(&source, &target)) + .mutate(|db| { + db.as_port_forwards_mut() + .insert(&source, &target) + .and_then(|replaced| { + if replaced.is_some() { + Err(Error::new( + eyre!("Port forward from {source} already exists"), + ErrorKind::InvalidRequest, + )) + } else { + Ok(()) + } + }) + }) .await .result?; diff --git a/core/startos/src/tunnel/context.rs b/core/startos/src/tunnel/context.rs index 421e8e2e6..c6b74fc70 100644 --- a/core/startos/src/tunnel/context.rs +++ b/core/startos/src/tunnel/context.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use std::net::{IpAddr, SocketAddr, SocketAddrV4}; use std::ops::Deref; use std::path::{Path, PathBuf}; @@ -9,7 +9,6 @@ use cookie::{Cookie, Expiration, SameSite}; use http::HeaderMap; use imbl::OrdMap; use imbl_value::InternedString; -use include_dir::Dir; use models::GatewayId; use patch_db::PatchDb; use rpc_toolkit::yajrc::RpcError; @@ -192,6 +191,12 @@ impl TunnelContext { shutdown, }))) } + + pub async fn gc_forwards(&self, keep: &BTreeSet) -> Result<(), Error> { + self.active_forwards + .mutate(|pf| pf.retain(|k, _| keep.contains(k))); + self.forward.gc().await + } } impl AsRef for TunnelContext { fn as_ref(&self) -> &RpcContinuations { diff --git a/core/startos/src/tunnel/db.rs b/core/startos/src/tunnel/db.rs index cbdbd46d2..9c6b48a7c 100644 --- a/core/startos/src/tunnel/db.rs +++ b/core/startos/src/tunnel/db.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use std::net::SocketAddrV4; use std::path::PathBuf; use std::time::Duration; @@ -45,6 +45,27 @@ pub struct TunnelDatabase { pub port_forwards: PortForwards, } +impl Model { + pub fn gc_forwards(&mut self) -> Result, Error> { + let mut keep_sources = BTreeSet::new(); + let mut keep_targets = BTreeSet::new(); + for (_, cfg) in self.as_wg().as_subnets().as_entries()? { + keep_targets.extend(cfg.as_clients().keys()?); + } + self.as_port_forwards_mut().mutate(|pf| { + Ok(pf.0.retain(|k, v| { + if keep_targets.contains(v.ip()) { + keep_sources.insert(*k); + true + } else { + false + } + })) + })?; + Ok(keep_sources) + } +} + #[test] fn export_bindings_tunnel_db() { TunnelDatabase::export_all_to("bindings/tunnel").unwrap(); diff --git a/core/startos/src/tunnel/web.rs b/core/startos/src/tunnel/web.rs index 121b36a3e..f38f37876 100644 --- a/core/startos/src/tunnel/web.rs +++ b/core/startos/src/tunnel/web.rs @@ -77,17 +77,18 @@ where .log_err()?; let cert_key = cert_info.key.0.private_key_to_pkcs8().log_err()?; - Some( - ServerConfig::builder_with_provider(self.crypto_provider.clone()) - .with_safe_default_protocol_versions() - .log_err()? - .with_no_client_auth() - .with_single_cert( - cert_chain, - PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert_key)), - ) - .log_err()?, - ) + let mut cfg = ServerConfig::builder_with_provider(self.crypto_provider.clone()) + .with_safe_default_protocol_versions() + .log_err()? + .with_no_client_auth() + .with_single_cert( + cert_chain, + PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert_key)), + ) + .log_err()?; + cfg.alpn_protocols + .extend([b"http/1.1".into(), b"h2".into()]); + Some(cfg) } } diff --git a/sdk/base/lib/actions/setupActions.ts b/sdk/base/lib/actions/setupActions.ts index 26a72a117..d8ab040bd 100644 --- a/sdk/base/lib/actions/setupActions.ts +++ b/sdk/base/lib/actions/setupActions.ts @@ -45,9 +45,10 @@ export interface ActionInfo< readonly _INPUT: Type } -export class Action> - implements ActionInfo -{ +export class Action< + Id extends T.ActionId, + Type extends Record, +> implements ActionInfo { readonly _INPUT: Type = null as any as Type private prevInputSpec: Record< string, @@ -148,8 +149,7 @@ export class Action> export class Actions< AllActions extends Record>, -> implements InitScript -{ +> implements InitScript { private constructor(private readonly actions: AllActions) {} static of(): Actions<{}> { return new Actions({}) diff --git a/sdk/base/lib/util/Drop.ts b/sdk/base/lib/util/Drop.ts index 62dd61f0f..cf4d9f302 100644 --- a/sdk/base/lib/util/Drop.ts +++ b/sdk/base/lib/util/Drop.ts @@ -109,9 +109,11 @@ export class DropPromise implements Promise { } } -export class DropGenerator - implements AsyncGenerator -{ +export class DropGenerator< + T = unknown, + TReturn = any, + TNext = unknown, +> implements AsyncGenerator { private static dropFns: { [id: number]: () => void } = {} private static registry = new FinalizationRegistry((id: number) => { const drop = DropGenerator.dropFns[id] diff --git a/sdk/package/lib/StartSdk.ts b/sdk/package/lib/StartSdk.ts index 01788f00e..a86ce2562 100644 --- a/sdk/package/lib/StartSdk.ts +++ b/sdk/package/lib/StartSdk.ts @@ -163,8 +163,8 @@ export class StartSdk { effects.action.clearTasks({ only: replayIds }), }, checkDependencies: checkDependencies as < - DependencyId extends keyof Manifest["dependencies"] & - PackageId = keyof Manifest["dependencies"] & PackageId, + DependencyId extends keyof Manifest["dependencies"] & PackageId = + keyof Manifest["dependencies"] & PackageId, >( effects: Effects, packageIds?: DependencyId[], diff --git a/sdk/package/lib/util/SubContainer.ts b/sdk/package/lib/util/SubContainer.ts index 879c6bafb..c64ffc8ae 100644 --- a/sdk/package/lib/util/SubContainer.ts +++ b/sdk/package/lib/util/SubContainer.ts @@ -137,9 +137,9 @@ export interface SubContainer< * Want to limit what we can do in a container, so we want to launch a container with a specific image and the mounts. */ export class SubContainerOwned< - Manifest extends T.SDKManifest, - Effects extends T.Effects = T.Effects, - > + Manifest extends T.SDKManifest, + Effects extends T.Effects = T.Effects, +> extends Drop implements SubContainer { @@ -615,9 +615,9 @@ export class SubContainerOwned< } export class SubContainerRc< - Manifest extends T.SDKManifest, - Effects extends T.Effects = T.Effects, - > + Manifest extends T.SDKManifest, + Effects extends T.Effects = T.Effects, +> extends Drop implements SubContainer {