improve StartTunnel validation and GC (#3062)

* improve StartTunnel validation and GC

* update sdk formatting
This commit is contained in:
Aiden McClelland
2025-11-28 13:14:52 -07:00
committed by GitHub
parent 72eb8b1eb6
commit a53b15f2a3
8 changed files with 81 additions and 37 deletions

View File

@@ -175,15 +175,16 @@ pub async fn remove_subnet(
_: Empty, _: Empty,
SubnetParams { subnet }: SubnetParams, SubnetParams { subnet }: SubnetParams,
) -> Result<(), Error> { ) -> Result<(), Error> {
let server = ctx let (server, keep) = ctx
.db .db
.mutate(|db| { .mutate(|db| {
db.as_wg_mut().as_subnets_mut().remove(&subnet)?; db.as_wg_mut().as_subnets_mut().remove(&subnet)?;
db.as_wg().de() Ok((db.as_wg().de()?, db.gc_forwards()?))
}) })
.await .await
.result?; .result?;
server.sync().await server.sync().await?;
ctx.gc_forwards(&keep).await
} }
#[derive(Deserialize, Serialize, Parser)] #[derive(Deserialize, Serialize, Parser)]
@@ -258,7 +259,7 @@ pub async fn remove_device(
ctx: TunnelContext, ctx: TunnelContext,
RemoveDeviceParams { subnet, ip }: RemoveDeviceParams, RemoveDeviceParams { subnet, ip }: RemoveDeviceParams,
) -> Result<(), Error> { ) -> Result<(), Error> {
let server = ctx let (server, keep) = ctx
.db .db
.mutate(|db| { .mutate(|db| {
db.as_wg_mut() db.as_wg_mut()
@@ -268,11 +269,12 @@ pub async fn remove_device(
.as_clients_mut() .as_clients_mut()
.remove(&ip)? .remove(&ip)?
.or_not_found(&ip)?; .or_not_found(&ip)?;
db.as_wg().de() Ok((db.as_wg().de()?, db.gc_forwards()?))
}) })
.await .await
.result?; .result?;
server.sync().await server.sync().await?;
ctx.gc_forwards(&keep).await
} }
#[derive(Deserialize, Serialize, Parser)] #[derive(Deserialize, Serialize, Parser)]
@@ -377,7 +379,20 @@ pub async fn add_forward(
}); });
ctx.db ctx.db
.mutate(|db| db.as_port_forwards_mut().insert(&source, &target)) .mutate(|db| {
db.as_port_forwards_mut()
.insert(&source, &target)
.and_then(|replaced| {
if replaced.is_some() {
Err(Error::new(
eyre!("Port forward from {source} already exists"),
ErrorKind::InvalidRequest,
))
} else {
Ok(())
}
})
})
.await .await
.result?; .result?;

View File

@@ -1,4 +1,4 @@
use std::collections::BTreeMap; use std::collections::{BTreeMap, BTreeSet};
use std::net::{IpAddr, SocketAddr, SocketAddrV4}; use std::net::{IpAddr, SocketAddr, SocketAddrV4};
use std::ops::Deref; use std::ops::Deref;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
@@ -9,7 +9,6 @@ use cookie::{Cookie, Expiration, SameSite};
use http::HeaderMap; use http::HeaderMap;
use imbl::OrdMap; use imbl::OrdMap;
use imbl_value::InternedString; use imbl_value::InternedString;
use include_dir::Dir;
use models::GatewayId; use models::GatewayId;
use patch_db::PatchDb; use patch_db::PatchDb;
use rpc_toolkit::yajrc::RpcError; use rpc_toolkit::yajrc::RpcError;
@@ -192,6 +191,12 @@ impl TunnelContext {
shutdown, shutdown,
}))) })))
} }
pub async fn gc_forwards(&self, keep: &BTreeSet<SocketAddrV4>) -> Result<(), Error> {
self.active_forwards
.mutate(|pf| pf.retain(|k, _| keep.contains(k)));
self.forward.gc().await
}
} }
impl AsRef<RpcContinuations> for TunnelContext { impl AsRef<RpcContinuations> for TunnelContext {
fn as_ref(&self) -> &RpcContinuations { fn as_ref(&self) -> &RpcContinuations {

View File

@@ -1,4 +1,4 @@
use std::collections::BTreeMap; use std::collections::{BTreeMap, BTreeSet};
use std::net::SocketAddrV4; use std::net::SocketAddrV4;
use std::path::PathBuf; use std::path::PathBuf;
use std::time::Duration; use std::time::Duration;
@@ -45,6 +45,27 @@ pub struct TunnelDatabase {
pub port_forwards: PortForwards, pub port_forwards: PortForwards,
} }
impl Model<TunnelDatabase> {
pub fn gc_forwards(&mut self) -> Result<BTreeSet<SocketAddrV4>, Error> {
let mut keep_sources = BTreeSet::new();
let mut keep_targets = BTreeSet::new();
for (_, cfg) in self.as_wg().as_subnets().as_entries()? {
keep_targets.extend(cfg.as_clients().keys()?);
}
self.as_port_forwards_mut().mutate(|pf| {
Ok(pf.0.retain(|k, v| {
if keep_targets.contains(v.ip()) {
keep_sources.insert(*k);
true
} else {
false
}
}))
})?;
Ok(keep_sources)
}
}
#[test] #[test]
fn export_bindings_tunnel_db() { fn export_bindings_tunnel_db() {
TunnelDatabase::export_all_to("bindings/tunnel").unwrap(); TunnelDatabase::export_all_to("bindings/tunnel").unwrap();

View File

@@ -77,17 +77,18 @@ where
.log_err()?; .log_err()?;
let cert_key = cert_info.key.0.private_key_to_pkcs8().log_err()?; let cert_key = cert_info.key.0.private_key_to_pkcs8().log_err()?;
Some( let mut cfg = ServerConfig::builder_with_provider(self.crypto_provider.clone())
ServerConfig::builder_with_provider(self.crypto_provider.clone()) .with_safe_default_protocol_versions()
.with_safe_default_protocol_versions() .log_err()?
.log_err()? .with_no_client_auth()
.with_no_client_auth() .with_single_cert(
.with_single_cert( cert_chain,
cert_chain, PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert_key)),
PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert_key)), )
) .log_err()?;
.log_err()?, cfg.alpn_protocols
) .extend([b"http/1.1".into(), b"h2".into()]);
Some(cfg)
} }
} }

View File

@@ -45,9 +45,10 @@ export interface ActionInfo<
readonly _INPUT: Type readonly _INPUT: Type
} }
export class Action<Id extends T.ActionId, Type extends Record<string, any>> export class Action<
implements ActionInfo<Id, Type> Id extends T.ActionId,
{ Type extends Record<string, any>,
> implements ActionInfo<Id, Type> {
readonly _INPUT: Type = null as any as Type readonly _INPUT: Type = null as any as Type
private prevInputSpec: Record< private prevInputSpec: Record<
string, string,
@@ -148,8 +149,7 @@ export class Action<Id extends T.ActionId, Type extends Record<string, any>>
export class Actions< export class Actions<
AllActions extends Record<T.ActionId, Action<T.ActionId, any>>, AllActions extends Record<T.ActionId, Action<T.ActionId, any>>,
> implements InitScript > implements InitScript {
{
private constructor(private readonly actions: AllActions) {} private constructor(private readonly actions: AllActions) {}
static of(): Actions<{}> { static of(): Actions<{}> {
return new Actions({}) return new Actions({})

View File

@@ -109,9 +109,11 @@ export class DropPromise<T> implements Promise<T> {
} }
} }
export class DropGenerator<T = unknown, TReturn = any, TNext = unknown> export class DropGenerator<
implements AsyncGenerator<T, TReturn, TNext> T = unknown,
{ TReturn = any,
TNext = unknown,
> implements AsyncGenerator<T, TReturn, TNext> {
private static dropFns: { [id: number]: () => void } = {} private static dropFns: { [id: number]: () => void } = {}
private static registry = new FinalizationRegistry((id: number) => { private static registry = new FinalizationRegistry((id: number) => {
const drop = DropGenerator.dropFns[id] const drop = DropGenerator.dropFns[id]

View File

@@ -163,8 +163,8 @@ export class StartSdk<Manifest extends T.SDKManifest> {
effects.action.clearTasks({ only: replayIds }), effects.action.clearTasks({ only: replayIds }),
}, },
checkDependencies: checkDependencies as < checkDependencies: checkDependencies as <
DependencyId extends keyof Manifest["dependencies"] & DependencyId extends keyof Manifest["dependencies"] & PackageId =
PackageId = keyof Manifest["dependencies"] & PackageId, keyof Manifest["dependencies"] & PackageId,
>( >(
effects: Effects, effects: Effects,
packageIds?: DependencyId[], packageIds?: DependencyId[],

View File

@@ -137,9 +137,9 @@ export interface SubContainer<
* Want to limit what we can do in a container, so we want to launch a container with a specific image and the mounts. * Want to limit what we can do in a container, so we want to launch a container with a specific image and the mounts.
*/ */
export class SubContainerOwned< export class SubContainerOwned<
Manifest extends T.SDKManifest, Manifest extends T.SDKManifest,
Effects extends T.Effects = T.Effects, Effects extends T.Effects = T.Effects,
> >
extends Drop extends Drop
implements SubContainer<Manifest, Effects> implements SubContainer<Manifest, Effects>
{ {
@@ -615,9 +615,9 @@ export class SubContainerOwned<
} }
export class SubContainerRc< export class SubContainerRc<
Manifest extends T.SDKManifest, Manifest extends T.SDKManifest,
Effects extends T.Effects = T.Effects, Effects extends T.Effects = T.Effects,
> >
extends Drop extends Drop
implements SubContainer<Manifest, Effects> implements SubContainer<Manifest, Effects>
{ {