use std::collections::{BTreeMap, BTreeSet}; use std::net::SocketAddrV4; use std::path::PathBuf; use std::time::Duration; use axum::extract::ws; use clap::Parser; use imbl::{HashMap, OrdMap}; use imbl_value::InternedString; use itertools::Itertools; use patch_db::Dump; use patch_db::json_ptr::{JsonPointer, ROOT}; use rpc_toolkit::yajrc::RpcError; use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn_async}; use serde::{Deserialize, Serialize}; use tracing::instrument; use ts_rs::TS; use crate::GatewayId; use crate::auth::Sessions; use crate::context::CliContext; use crate::db::model::public::NetworkInterfaceInfo; use crate::prelude::*; use crate::rpc_continuations::{Guid, RpcContinuation}; use crate::sign::AnyVerifyingKey; use crate::tunnel::auth::SignerInfo; use crate::tunnel::context::TunnelContext; use crate::tunnel::web::WebserverInfo; use crate::tunnel::wg::WgServer; use crate::util::net::WebSocketExt; use crate::util::serde::{HandlerExtSerde, apply_expr}; #[derive(Default, Deserialize, Serialize, HasModel, TS)] #[serde(rename_all = "camelCase")] #[model = "Model"] pub struct TunnelDatabase { pub webserver: WebserverInfo, pub sessions: Sessions, pub password: Option, #[ts(as = "std::collections::HashMap::")] pub auth_pubkeys: HashMap, #[ts(as = "std::collections::BTreeMap::")] pub gateways: OrdMap, pub wg: WgServer, pub port_forwards: PortForwards, } impl Model { pub fn gc_forwards(&mut self) -> Result, Error> { let mut keep_sources = BTreeSet::new(); let mut keep_targets = BTreeSet::new(); for (_, cfg) in self.as_wg().as_subnets().as_entries()? { keep_targets.extend(cfg.as_clients().keys()?); } self.as_port_forwards_mut().mutate(|pf| { Ok(pf.0.retain(|k, v| { if keep_targets.contains(v.ip()) { keep_sources.insert(*k); true } else { false } })) })?; Ok(keep_sources) } } #[test] fn export_bindings_tunnel_db() { TunnelDatabase::export_all_to("bindings/tunnel").unwrap(); } #[derive(Clone, Debug, Default, Deserialize, Serialize, TS)] pub struct PortForwards(pub BTreeMap); impl Map for PortForwards { type Key = SocketAddrV4; type Value = SocketAddrV4; fn key_str(key: &Self::Key) -> Result, Error> { Self::key_string(key) } fn key_string(key: &Self::Key) -> Result { Ok(InternedString::from_display(key)) } } pub fn db_api() -> ParentHandler { ParentHandler::new() .subcommand( "dump", from_fn_async(cli_dump) .with_display_serializable() .with_about("Filter/query db to display tables and records"), ) .subcommand( "dump", from_fn_async(dump) .with_metadata("admin", Value::Bool(true)) .no_cli(), ) .subcommand( "subscribe", from_fn_async(subscribe) .with_metadata("get_session", Value::Bool(true)) .no_cli(), ) .subcommand( "apply", from_fn_async(cli_apply) .no_display() .with_about("Update a db record"), ) .subcommand( "apply", from_fn_async(apply) .with_metadata("admin", Value::Bool(true)) .no_cli(), ) } #[derive(Deserialize, Serialize, Parser)] #[serde(rename_all = "camelCase")] #[command(rename_all = "kebab-case")] pub struct CliDumpParams { #[arg(long = "pointer", short = 'p')] pointer: Option, path: Option, } #[instrument(skip_all)] async fn cli_dump( HandlerArgs { context, parent_method, method, params: CliDumpParams { pointer, path }, .. }: HandlerArgs, ) -> Result { let dump = if let Some(path) = path { PatchDb::open(path).await?.dump(&ROOT).await } else { let method = parent_method.into_iter().chain(method).join("."); from_value::( context .call_remote::(&method, imbl_value::json!({ "pointer": pointer })) .await?, )? }; Ok(dump) } #[derive(Deserialize, Serialize, Parser, TS)] #[serde(rename_all = "camelCase")] #[command(rename_all = "kebab-case")] pub struct DumpParams { #[arg(long = "pointer", short = 'p')] #[ts(type = "string | null")] pointer: Option, } pub async fn dump(ctx: TunnelContext, DumpParams { pointer }: DumpParams) -> Result { Ok(ctx .db .dump(&pointer.as_ref().map_or(ROOT, |p| p.borrowed())) .await) } #[derive(Deserialize, Serialize, Parser)] #[serde(rename_all = "camelCase")] #[command(rename_all = "kebab-case")] pub struct CliApplyParams { expr: String, path: Option, } #[instrument(skip_all)] async fn cli_apply( HandlerArgs { context, parent_method, method, params: CliApplyParams { expr, path }, .. }: HandlerArgs, ) -> Result<(), RpcError> { if let Some(path) = path { PatchDb::open(path) .await? .apply_function(|db| { let res = apply_expr( serde_json::to_value(patch_db::Value::from(db)) .with_kind(ErrorKind::Deserialization)? .into(), &expr, )?; Ok::<_, Error>(( to_value( &serde_json::from_value::(res.clone().into()).with_ctx( |_| { ( crate::ErrorKind::Deserialization, "result does not match database model", ) }, )?, )?, (), )) }) .await .result?; } else { let method = parent_method.into_iter().chain(method).join("."); context .call_remote::(&method, imbl_value::json!({ "expr": expr })) .await?; } Ok(()) } #[derive(Deserialize, Serialize, Parser, TS)] #[serde(rename_all = "camelCase")] #[command(rename_all = "kebab-case")] pub struct ApplyParams { expr: String, path: Option, } pub async fn apply(ctx: TunnelContext, ApplyParams { expr, .. }: ApplyParams) -> Result<(), Error> { ctx.db .mutate(|db| { let res = apply_expr( serde_json::to_value(patch_db::Value::from(db.clone())) .with_kind(ErrorKind::Deserialization)? .into(), &expr, )?; db.ser( &serde_json::from_value::(res.clone().into()).with_ctx(|_| { ( crate::ErrorKind::Deserialization, "result does not match database model", ) })?, ) }) .await .result } #[derive(Deserialize, Serialize, TS)] #[serde(rename_all = "camelCase")] pub struct SubscribeParams { #[ts(type = "string | null")] pointer: Option, #[ts(skip)] #[serde(rename = "__Auth_session")] session: Option, } #[derive(Deserialize, Serialize, TS)] #[serde(rename_all = "camelCase")] pub struct SubscribeRes { #[ts(type = "{ id: number; value: unknown }")] pub dump: Dump, pub guid: Guid, } pub async fn subscribe( ctx: TunnelContext, SubscribeParams { pointer, session }: SubscribeParams, ) -> Result { let (dump, mut sub) = ctx .db .dump_and_sub(pointer.unwrap_or_else(|| ROOT.to_owned())) .await; let guid = Guid::new(); ctx.rpc_continuations .add( guid.clone(), RpcContinuation::ws_authed( &ctx, session, |mut ws| async move { if let Err(e) = async { loop { tokio::select! { rev = sub.recv() => { if let Some(rev) = rev { ws.send(ws::Message::Text( serde_json::to_string(&rev) .with_kind(ErrorKind::Serialization)? .into(), )) .await .with_kind(ErrorKind::Network)?; } else { return ws.normal_close("complete").await; } } msg = ws.recv() => { if msg.transpose().with_kind(ErrorKind::Network)?.is_none() { return Ok(()) } } } } } .await { tracing::error!("Error in db websocket: {e}"); tracing::debug!("{e:?}"); } }, Duration::from_secs(30), ), ) .await; Ok(SubscribeRes { dump, guid }) }