improve StartTunnel validation and GC (#3062)

* improve StartTunnel validation and GC

* update sdk formatting
This commit is contained in:
Aiden McClelland
2025-11-28 13:14:52 -07:00
committed by GitHub
parent 72eb8b1eb6
commit a53b15f2a3
8 changed files with 81 additions and 37 deletions

View File

@@ -175,15 +175,16 @@ pub async fn remove_subnet(
_: Empty,
SubnetParams { subnet }: SubnetParams,
) -> Result<(), Error> {
let server = ctx
let (server, keep) = ctx
.db
.mutate(|db| {
db.as_wg_mut().as_subnets_mut().remove(&subnet)?;
db.as_wg().de()
Ok((db.as_wg().de()?, db.gc_forwards()?))
})
.await
.result?;
server.sync().await
server.sync().await?;
ctx.gc_forwards(&keep).await
}
#[derive(Deserialize, Serialize, Parser)]
@@ -258,7 +259,7 @@ pub async fn remove_device(
ctx: TunnelContext,
RemoveDeviceParams { subnet, ip }: RemoveDeviceParams,
) -> Result<(), Error> {
let server = ctx
let (server, keep) = ctx
.db
.mutate(|db| {
db.as_wg_mut()
@@ -268,11 +269,12 @@ pub async fn remove_device(
.as_clients_mut()
.remove(&ip)?
.or_not_found(&ip)?;
db.as_wg().de()
Ok((db.as_wg().de()?, db.gc_forwards()?))
})
.await
.result?;
server.sync().await
server.sync().await?;
ctx.gc_forwards(&keep).await
}
#[derive(Deserialize, Serialize, Parser)]
@@ -377,7 +379,20 @@ pub async fn add_forward(
});
ctx.db
.mutate(|db| db.as_port_forwards_mut().insert(&source, &target))
.mutate(|db| {
db.as_port_forwards_mut()
.insert(&source, &target)
.and_then(|replaced| {
if replaced.is_some() {
Err(Error::new(
eyre!("Port forward from {source} already exists"),
ErrorKind::InvalidRequest,
))
} else {
Ok(())
}
})
})
.await
.result?;

View File

@@ -1,4 +1,4 @@
use std::collections::BTreeMap;
use std::collections::{BTreeMap, BTreeSet};
use std::net::{IpAddr, SocketAddr, SocketAddrV4};
use std::ops::Deref;
use std::path::{Path, PathBuf};
@@ -9,7 +9,6 @@ use cookie::{Cookie, Expiration, SameSite};
use http::HeaderMap;
use imbl::OrdMap;
use imbl_value::InternedString;
use include_dir::Dir;
use models::GatewayId;
use patch_db::PatchDb;
use rpc_toolkit::yajrc::RpcError;
@@ -192,6 +191,12 @@ impl TunnelContext {
shutdown,
})))
}
pub async fn gc_forwards(&self, keep: &BTreeSet<SocketAddrV4>) -> Result<(), Error> {
self.active_forwards
.mutate(|pf| pf.retain(|k, _| keep.contains(k)));
self.forward.gc().await
}
}
impl AsRef<RpcContinuations> for TunnelContext {
fn as_ref(&self) -> &RpcContinuations {

View File

@@ -1,4 +1,4 @@
use std::collections::BTreeMap;
use std::collections::{BTreeMap, BTreeSet};
use std::net::SocketAddrV4;
use std::path::PathBuf;
use std::time::Duration;
@@ -45,6 +45,27 @@ pub struct TunnelDatabase {
pub port_forwards: PortForwards,
}
impl Model<TunnelDatabase> {
pub fn gc_forwards(&mut self) -> Result<BTreeSet<SocketAddrV4>, Error> {
let mut keep_sources = BTreeSet::new();
let mut keep_targets = BTreeSet::new();
for (_, cfg) in self.as_wg().as_subnets().as_entries()? {
keep_targets.extend(cfg.as_clients().keys()?);
}
self.as_port_forwards_mut().mutate(|pf| {
Ok(pf.0.retain(|k, v| {
if keep_targets.contains(v.ip()) {
keep_sources.insert(*k);
true
} else {
false
}
}))
})?;
Ok(keep_sources)
}
}
#[test]
fn export_bindings_tunnel_db() {
TunnelDatabase::export_all_to("bindings/tunnel").unwrap();

View File

@@ -77,17 +77,18 @@ where
.log_err()?;
let cert_key = cert_info.key.0.private_key_to_pkcs8().log_err()?;
Some(
ServerConfig::builder_with_provider(self.crypto_provider.clone())
.with_safe_default_protocol_versions()
.log_err()?
.with_no_client_auth()
.with_single_cert(
cert_chain,
PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert_key)),
)
.log_err()?,
)
let mut cfg = ServerConfig::builder_with_provider(self.crypto_provider.clone())
.with_safe_default_protocol_versions()
.log_err()?
.with_no_client_auth()
.with_single_cert(
cert_chain,
PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert_key)),
)
.log_err()?;
cfg.alpn_protocols
.extend([b"http/1.1".into(), b"h2".into()]);
Some(cfg)
}
}