use std::collections::BTreeMap; use std::net::SocketAddrV4; use std::path::PathBuf; use clap::Parser; use clap::builder::ValueParserFactory; use imbl::{HashMap, OrdMap}; use imbl_value::InternedString; use itertools::Itertools; use models::{FromStrParser, GatewayId}; use patch_db::Dump; use patch_db::json_ptr::{JsonPointer, ROOT}; use rpc_toolkit::yajrc::RpcError; use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn_async}; use serde::{Deserialize, Serialize}; use tracing::instrument; use ts_rs::TS; use crate::auth::Sessions; use crate::context::CliContext; use crate::db::model::public::NetworkInterfaceInfo; use crate::prelude::*; use crate::sign::AnyVerifyingKey; use crate::tunnel::auth::SignerInfo; use crate::tunnel::context::TunnelContext; use crate::tunnel::web::WebserverInfo; use crate::tunnel::wg::WgServer; use crate::util::serde::{HandlerExtSerde, apply_expr, deserialize_from_str, serialize_display}; #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct GatewayPort(pub GatewayId, pub u16); impl std::fmt::Display for GatewayPort { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}:{}", self.0, self.1) } } impl std::str::FromStr for GatewayPort { type Err = crate::Error; fn from_str(s: &str) -> Result { let mut parts = s.splitn(2, ':'); let gw: GatewayId = parts .next() .ok_or_else(|| Error::new(eyre!("missing gateway id"), ErrorKind::ParseNetAddress))? .parse()?; let port: u16 = parts .next() .ok_or_else(|| Error::new(eyre!("missing port"), ErrorKind::ParseNetAddress))? .parse()?; Ok(GatewayPort(gw, port)) } } impl Serialize for GatewayPort { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { serialize_display(self, serializer) } } impl<'de> Deserialize<'de> for GatewayPort { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { deserialize_from_str(deserializer) } } impl ValueParserFactory for GatewayPort { type Parser = FromStrParser; fn value_parser() -> Self::Parser { FromStrParser::new() } } #[derive(Default, Deserialize, Serialize, HasModel)] #[serde(rename_all = "camelCase")] #[model = "Model"] pub struct TunnelDatabase { pub webserver: WebserverInfo, pub sessions: Sessions, pub password: Option, pub auth_pubkeys: HashMap, pub gateways: OrdMap, pub wg: WgServer, pub port_forwards: PortForwards, } #[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct PortForwards(pub BTreeMap); impl Map for PortForwards { type Key = GatewayPort; type Value = SocketAddrV4; fn key_str(key: &Self::Key) -> Result, Error> { Self::key_string(key) } fn key_string(key: &Self::Key) -> Result { Ok(InternedString::from_display(key)) } } pub fn db_api() -> ParentHandler { ParentHandler::new() .subcommand( "dump", from_fn_async(cli_dump) .with_display_serializable() .with_about("Filter/query db to display tables and records"), ) .subcommand( "dump", from_fn_async(dump) .with_metadata("admin", Value::Bool(true)) .no_cli(), ) .subcommand( "apply", from_fn_async(cli_apply) .no_display() .with_about("Update a db record"), ) .subcommand( "apply", from_fn_async(apply) .with_metadata("admin", Value::Bool(true)) .no_cli(), ) } #[derive(Deserialize, Serialize, Parser)] #[serde(rename_all = "camelCase")] #[command(rename_all = "kebab-case")] pub struct CliDumpParams { #[arg(long = "pointer", short = 'p')] pointer: Option, path: Option, } #[instrument(skip_all)] async fn cli_dump( HandlerArgs { context, parent_method, method, params: CliDumpParams { pointer, path }, .. }: HandlerArgs, ) -> Result { let dump = if let Some(path) = path { PatchDb::open(path).await?.dump(&ROOT).await } else { let method = parent_method.into_iter().chain(method).join("."); from_value::( context .call_remote::(&method, imbl_value::json!({ "pointer": pointer })) .await?, )? }; Ok(dump) } #[derive(Deserialize, Serialize, Parser, TS)] #[serde(rename_all = "camelCase")] #[command(rename_all = "kebab-case")] pub struct DumpParams { #[arg(long = "pointer", short = 'p')] #[ts(type = "string | null")] pointer: Option, } pub async fn dump(ctx: TunnelContext, DumpParams { pointer }: DumpParams) -> Result { Ok(ctx .db .dump(&pointer.as_ref().map_or(ROOT, |p| p.borrowed())) .await) } #[derive(Deserialize, Serialize, Parser)] #[serde(rename_all = "camelCase")] #[command(rename_all = "kebab-case")] pub struct CliApplyParams { expr: String, path: Option, } #[instrument(skip_all)] async fn cli_apply( HandlerArgs { context, parent_method, method, params: CliApplyParams { expr, path }, .. }: HandlerArgs, ) -> Result<(), RpcError> { if let Some(path) = path { PatchDb::open(path) .await? .apply_function(|db| { let res = apply_expr( serde_json::to_value(patch_db::Value::from(db)) .with_kind(ErrorKind::Deserialization)? .into(), &expr, )?; Ok::<_, Error>(( to_value( &serde_json::from_value::(res.clone().into()).with_ctx( |_| { ( crate::ErrorKind::Deserialization, "result does not match database model", ) }, )?, )?, (), )) }) .await .result?; } else { let method = parent_method.into_iter().chain(method).join("."); context .call_remote::(&method, imbl_value::json!({ "expr": expr })) .await?; } Ok(()) } #[derive(Deserialize, Serialize, Parser, TS)] #[serde(rename_all = "camelCase")] #[command(rename_all = "kebab-case")] pub struct ApplyParams { expr: String, path: Option, } pub async fn apply(ctx: TunnelContext, ApplyParams { expr, .. }: ApplyParams) -> Result<(), Error> { ctx.db .mutate(|db| { let res = apply_expr( serde_json::to_value(patch_db::Value::from(db.clone())) .with_kind(ErrorKind::Deserialization)? .into(), &expr, )?; db.ser( &serde_json::from_value::(res.clone().into()).with_ctx(|_| { ( crate::ErrorKind::Deserialization, "result does not match database model", ) })?, ) }) .await .result }