mirror of
https://github.com/Start9Labs/start-os.git
synced 2026-03-26 02:11:53 +00:00
improve StartTunnel validation and GC (#3062)
* improve StartTunnel validation and GC * update sdk formatting
This commit is contained in:
@@ -175,15 +175,16 @@ pub async fn remove_subnet(
|
||||
_: Empty,
|
||||
SubnetParams { subnet }: SubnetParams,
|
||||
) -> Result<(), Error> {
|
||||
let server = ctx
|
||||
let (server, keep) = ctx
|
||||
.db
|
||||
.mutate(|db| {
|
||||
db.as_wg_mut().as_subnets_mut().remove(&subnet)?;
|
||||
db.as_wg().de()
|
||||
Ok((db.as_wg().de()?, db.gc_forwards()?))
|
||||
})
|
||||
.await
|
||||
.result?;
|
||||
server.sync().await
|
||||
server.sync().await?;
|
||||
ctx.gc_forwards(&keep).await
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Parser)]
|
||||
@@ -258,7 +259,7 @@ pub async fn remove_device(
|
||||
ctx: TunnelContext,
|
||||
RemoveDeviceParams { subnet, ip }: RemoveDeviceParams,
|
||||
) -> Result<(), Error> {
|
||||
let server = ctx
|
||||
let (server, keep) = ctx
|
||||
.db
|
||||
.mutate(|db| {
|
||||
db.as_wg_mut()
|
||||
@@ -268,11 +269,12 @@ pub async fn remove_device(
|
||||
.as_clients_mut()
|
||||
.remove(&ip)?
|
||||
.or_not_found(&ip)?;
|
||||
db.as_wg().de()
|
||||
Ok((db.as_wg().de()?, db.gc_forwards()?))
|
||||
})
|
||||
.await
|
||||
.result?;
|
||||
server.sync().await
|
||||
server.sync().await?;
|
||||
ctx.gc_forwards(&keep).await
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Parser)]
|
||||
@@ -377,7 +379,20 @@ pub async fn add_forward(
|
||||
});
|
||||
|
||||
ctx.db
|
||||
.mutate(|db| db.as_port_forwards_mut().insert(&source, &target))
|
||||
.mutate(|db| {
|
||||
db.as_port_forwards_mut()
|
||||
.insert(&source, &target)
|
||||
.and_then(|replaced| {
|
||||
if replaced.is_some() {
|
||||
Err(Error::new(
|
||||
eyre!("Port forward from {source} already exists"),
|
||||
ErrorKind::InvalidRequest,
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
})
|
||||
.await
|
||||
.result?;
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::net::{IpAddr, SocketAddr, SocketAddrV4};
|
||||
use std::ops::Deref;
|
||||
use std::path::{Path, PathBuf};
|
||||
@@ -9,7 +9,6 @@ use cookie::{Cookie, Expiration, SameSite};
|
||||
use http::HeaderMap;
|
||||
use imbl::OrdMap;
|
||||
use imbl_value::InternedString;
|
||||
use include_dir::Dir;
|
||||
use models::GatewayId;
|
||||
use patch_db::PatchDb;
|
||||
use rpc_toolkit::yajrc::RpcError;
|
||||
@@ -192,6 +191,12 @@ impl TunnelContext {
|
||||
shutdown,
|
||||
})))
|
||||
}
|
||||
|
||||
pub async fn gc_forwards(&self, keep: &BTreeSet<SocketAddrV4>) -> Result<(), Error> {
|
||||
self.active_forwards
|
||||
.mutate(|pf| pf.retain(|k, _| keep.contains(k)));
|
||||
self.forward.gc().await
|
||||
}
|
||||
}
|
||||
impl AsRef<RpcContinuations> for TunnelContext {
|
||||
fn as_ref(&self) -> &RpcContinuations {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::net::SocketAddrV4;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
@@ -45,6 +45,27 @@ pub struct TunnelDatabase {
|
||||
pub port_forwards: PortForwards,
|
||||
}
|
||||
|
||||
impl Model<TunnelDatabase> {
|
||||
pub fn gc_forwards(&mut self) -> Result<BTreeSet<SocketAddrV4>, Error> {
|
||||
let mut keep_sources = BTreeSet::new();
|
||||
let mut keep_targets = BTreeSet::new();
|
||||
for (_, cfg) in self.as_wg().as_subnets().as_entries()? {
|
||||
keep_targets.extend(cfg.as_clients().keys()?);
|
||||
}
|
||||
self.as_port_forwards_mut().mutate(|pf| {
|
||||
Ok(pf.0.retain(|k, v| {
|
||||
if keep_targets.contains(v.ip()) {
|
||||
keep_sources.insert(*k);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}))
|
||||
})?;
|
||||
Ok(keep_sources)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn export_bindings_tunnel_db() {
|
||||
TunnelDatabase::export_all_to("bindings/tunnel").unwrap();
|
||||
|
||||
@@ -77,17 +77,18 @@ where
|
||||
.log_err()?;
|
||||
let cert_key = cert_info.key.0.private_key_to_pkcs8().log_err()?;
|
||||
|
||||
Some(
|
||||
ServerConfig::builder_with_provider(self.crypto_provider.clone())
|
||||
.with_safe_default_protocol_versions()
|
||||
.log_err()?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(
|
||||
cert_chain,
|
||||
PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert_key)),
|
||||
)
|
||||
.log_err()?,
|
||||
)
|
||||
let mut cfg = ServerConfig::builder_with_provider(self.crypto_provider.clone())
|
||||
.with_safe_default_protocol_versions()
|
||||
.log_err()?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(
|
||||
cert_chain,
|
||||
PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert_key)),
|
||||
)
|
||||
.log_err()?;
|
||||
cfg.alpn_protocols
|
||||
.extend([b"http/1.1".into(), b"h2".into()]);
|
||||
Some(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -45,9 +45,10 @@ export interface ActionInfo<
|
||||
readonly _INPUT: Type
|
||||
}
|
||||
|
||||
export class Action<Id extends T.ActionId, Type extends Record<string, any>>
|
||||
implements ActionInfo<Id, Type>
|
||||
{
|
||||
export class Action<
|
||||
Id extends T.ActionId,
|
||||
Type extends Record<string, any>,
|
||||
> implements ActionInfo<Id, Type> {
|
||||
readonly _INPUT: Type = null as any as Type
|
||||
private prevInputSpec: Record<
|
||||
string,
|
||||
@@ -148,8 +149,7 @@ export class Action<Id extends T.ActionId, Type extends Record<string, any>>
|
||||
|
||||
export class Actions<
|
||||
AllActions extends Record<T.ActionId, Action<T.ActionId, any>>,
|
||||
> implements InitScript
|
||||
{
|
||||
> implements InitScript {
|
||||
private constructor(private readonly actions: AllActions) {}
|
||||
static of(): Actions<{}> {
|
||||
return new Actions({})
|
||||
|
||||
@@ -109,9 +109,11 @@ export class DropPromise<T> implements Promise<T> {
|
||||
}
|
||||
}
|
||||
|
||||
export class DropGenerator<T = unknown, TReturn = any, TNext = unknown>
|
||||
implements AsyncGenerator<T, TReturn, TNext>
|
||||
{
|
||||
export class DropGenerator<
|
||||
T = unknown,
|
||||
TReturn = any,
|
||||
TNext = unknown,
|
||||
> implements AsyncGenerator<T, TReturn, TNext> {
|
||||
private static dropFns: { [id: number]: () => void } = {}
|
||||
private static registry = new FinalizationRegistry((id: number) => {
|
||||
const drop = DropGenerator.dropFns[id]
|
||||
|
||||
@@ -163,8 +163,8 @@ export class StartSdk<Manifest extends T.SDKManifest> {
|
||||
effects.action.clearTasks({ only: replayIds }),
|
||||
},
|
||||
checkDependencies: checkDependencies as <
|
||||
DependencyId extends keyof Manifest["dependencies"] &
|
||||
PackageId = keyof Manifest["dependencies"] & PackageId,
|
||||
DependencyId extends keyof Manifest["dependencies"] & PackageId =
|
||||
keyof Manifest["dependencies"] & PackageId,
|
||||
>(
|
||||
effects: Effects,
|
||||
packageIds?: DependencyId[],
|
||||
|
||||
@@ -137,9 +137,9 @@ export interface SubContainer<
|
||||
* Want to limit what we can do in a container, so we want to launch a container with a specific image and the mounts.
|
||||
*/
|
||||
export class SubContainerOwned<
|
||||
Manifest extends T.SDKManifest,
|
||||
Effects extends T.Effects = T.Effects,
|
||||
>
|
||||
Manifest extends T.SDKManifest,
|
||||
Effects extends T.Effects = T.Effects,
|
||||
>
|
||||
extends Drop
|
||||
implements SubContainer<Manifest, Effects>
|
||||
{
|
||||
@@ -615,9 +615,9 @@ export class SubContainerOwned<
|
||||
}
|
||||
|
||||
export class SubContainerRc<
|
||||
Manifest extends T.SDKManifest,
|
||||
Effects extends T.Effects = T.Effects,
|
||||
>
|
||||
Manifest extends T.SDKManifest,
|
||||
Effects extends T.Effects = T.Effects,
|
||||
>
|
||||
extends Drop
|
||||
implements SubContainer<Manifest, Effects>
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user