diff --git a/core/startos/src/bins/tunnel.rs b/core/startos/src/bins/tunnel.rs index dde7d116d..3a2e2337a 100644 --- a/core/startos/src/bins/tunnel.rs +++ b/core/startos/src/bins/tunnel.rs @@ -46,7 +46,7 @@ async fn inner_main(config: &TunnelConfig) -> Result<(), Error> { let https_db = ctx.db.clone(); let https_thread: NonDetachingJoinHandle<()> = tokio::spawn(async move { let mut sub = https_db.subscribe("/webserver".parse().unwrap()).await; - while sub.recv().await.is_some() { + while { while let Err(e) = async { let webserver = https_db.peek().await.into_webserver(); if webserver.as_enabled().de()? { @@ -96,7 +96,8 @@ async fn inner_main(config: &TunnelConfig) -> Result<(), Error> { tracing::debug!("{e:?}"); tokio::time::sleep(Duration::from_secs(5)).await; } - } + sub.recv().await.is_some() + } {} }) .into(); diff --git a/core/startos/src/tunnel/api.rs b/core/startos/src/tunnel/api.rs index de9896932..238143fc0 100644 --- a/core/startos/src/tunnel/api.rs +++ b/core/startos/src/tunnel/api.rs @@ -11,7 +11,7 @@ use crate::net::gateway::{IdFilter, InterfaceFilter}; use crate::prelude::*; use crate::tunnel::context::TunnelContext; use crate::tunnel::db::GatewayPort; -use crate::tunnel::wg::{ClientConfig, WgConfig, WgSubnetClients, WgSubnetConfig}; +use crate::tunnel::wg::{WgConfig, WgSubnetClients, WgSubnetConfig}; use crate::util::serde::{HandlerExtSerde, display_serializable}; pub fn tunnel_api() -> ParentHandler { @@ -30,6 +30,10 @@ pub fn tunnel_api() -> ParentHandler { "subnet", subnet_api::().with_about("Add, remove, or modify subnets"), ) + .subcommand( + "device", + device_api::().with_about("Add, remove, or list devices in subnets"), + ) .subcommand( "port-forward", ParentHandler::::new() @@ -78,28 +82,29 @@ pub fn subnet_api() -> ParentHandler { .with_about("Remove a subnet") .with_call_remote::(), ) +} + +pub fn device_api() -> ParentHandler { + ParentHandler::new() .subcommand( - "add-device", + "add", from_fn_async(add_device) .with_metadata("sync_db", Value::Bool(true)) - .with_inherited(|a, _| a) .no_display() .with_about("Add a device to a subnet") .with_call_remote::(), ) .subcommand( - "remove-device", + "remove", from_fn_async(remove_device) .with_metadata("sync_db", Value::Bool(true)) - .with_inherited(|a, _| a) .no_display() .with_about("Remove a device from a subnet") .with_call_remote::(), ) .subcommand( - "list-devices", + "list", from_fn_async(list_devices) - .with_inherited(|a, _| a) .with_display_serializable() .with_custom_display_fn(|HandlerArgs { params, .. }, res| { use prettytable::*; @@ -124,18 +129,7 @@ pub fn subnet_api() -> ParentHandler { .subcommand( "show-config", from_fn_async(show_config) - .with_inherited(|a, _| a) - .with_display_serializable() - .with_custom_display_fn(|HandlerArgs { params, .. }, res| { - if let Some(format) = params.format { - return display_serializable(format, res); - } - - println!("{}", res); - - Ok(()) - }) - .with_about("Show the WireGuard configuration for a subnet") + .with_about("Show the WireGuard configuration for a device") .with_call_remote::(), ) } @@ -195,14 +189,14 @@ pub async fn remove_subnet( #[derive(Deserialize, Serialize, Parser)] #[serde(rename_all = "camelCase")] pub struct AddDeviceParams { + subnet: Ipv4Net, name: InternedString, ip: Option, } pub async fn add_device( ctx: TunnelContext, - AddDeviceParams { name, ip }: AddDeviceParams, - SubnetParams { subnet }: SubnetParams, + AddDeviceParams { subnet, name, ip }: AddDeviceParams, ) -> Result<(), Error> { let config = WgConfig::generate(name); let server = ctx @@ -254,13 +248,13 @@ pub async fn add_device( #[derive(Deserialize, Serialize, Parser)] #[serde(rename_all = "camelCase")] pub struct RemoveDeviceParams { + subnet: Ipv4Net, device: Ipv4Addr, } pub async fn remove_device( ctx: TunnelContext, - RemoveDeviceParams { device }: RemoveDeviceParams, - SubnetParams { subnet }: SubnetParams, + RemoveDeviceParams { subnet, device }: RemoveDeviceParams, ) -> Result<(), Error> { let server = ctx .db @@ -279,10 +273,15 @@ pub async fn remove_device( server.sync().await } +#[derive(Deserialize, Serialize, Parser)] +#[serde(rename_all = "camelCase")] +pub struct ListDevicesParams { + subnet: Ipv4Net, +} + pub async fn list_devices( ctx: TunnelContext, - _: Empty, - SubnetParams { subnet }: SubnetParams, + ListDevicesParams { subnet }: ListDevicesParams, ) -> Result { ctx.db .peek() @@ -297,7 +296,8 @@ pub async fn list_devices( #[derive(Deserialize, Serialize, Parser)] #[serde(rename_all = "camelCase")] pub struct ShowConfigParams { - device: Ipv4Addr, + subnet: Ipv4Net, + ip: Ipv4Addr, wan_addr: Option, #[serde(rename = "__ConnectInfo_local_addr")] #[arg(skip)] @@ -307,12 +307,12 @@ pub struct ShowConfigParams { pub async fn show_config( ctx: TunnelContext, ShowConfigParams { - device, + subnet, + ip, wan_addr, local_addr, }: ShowConfigParams, - SubnetParams { subnet }: SubnetParams, -) -> Result { +) -> Result { let peek = ctx.db.peek().await; let wg = peek.as_wg(); let client = wg @@ -320,8 +320,8 @@ pub async fn show_config( .as_idx(&subnet) .or_not_found(&subnet)? .as_clients() - .as_idx(&device) - .or_not_found(&device)? + .as_idx(&ip) + .or_not_found(&ip)? .de()?; let wan_addr = if let Some(wan_addr) = wan_addr.or(local_addr.map(|a| a.ip())).filter(|ip| { !ip.is_loopback() @@ -348,11 +348,13 @@ pub async fn show_config( .or_not_found("a public IP address")? .addr() }; - Ok(client.client_config( - device, - wg.as_key().de()?.verifying_key(), - (wan_addr, wg.as_port().de()?).into(), - )) + Ok(client + .client_config( + ip, + wg.as_key().de()?.verifying_key(), + (wan_addr, wg.as_port().de()?).into(), + ) + .to_string()) } #[derive(Deserialize, Serialize, Parser)] diff --git a/core/startos/src/tunnel/auth.rs b/core/startos/src/tunnel/auth.rs index 8818bfbec..123ea9982 100644 --- a/core/startos/src/tunnel/auth.rs +++ b/core/startos/src/tunnel/auth.rs @@ -120,6 +120,20 @@ pub struct SignerInfo { pub fn auth_api() -> ParentHandler { ParentHandler::new() + .subcommand( + "login", + from_fn_async(crate::auth::login_impl::) + .with_metadata("login", Value::Bool(true)) + .no_cli(), + ) + .subcommand( + "logout", + from_fn_async(crate::auth::logout::) + .with_metadata("get_session", Value::Bool(true)) + .no_display() + .with_about("Log out of current auth session") + .with_call_remote::(), + ) .subcommand("set-password", from_fn_async(set_password_rpc).no_cli()) .subcommand( "set-password", diff --git a/core/startos/src/tunnel/context.rs b/core/startos/src/tunnel/context.rs index 217c74b8c..51b32774e 100644 --- a/core/startos/src/tunnel/context.rs +++ b/core/startos/src/tunnel/context.rs @@ -23,7 +23,7 @@ use url::Url; use crate::auth::Sessions; use crate::context::config::ContextConfig; use crate::context::{CliContext, RpcContext}; -use crate::db::model::public::NetworkInterfaceInfo; +use crate::db::model::public::{NetworkInterfaceInfo, NetworkInterfaceType}; use crate::else_empty_dir; use crate::middleware::auth::{Auth, AuthContext}; use crate::middleware::cors::Cors; @@ -140,6 +140,11 @@ impl TunnelContext { for iface in net_iface.peek(|i| { i.iter() + .filter(|(_, info)| { + info.ip_info.as_ref().map_or(false, |i| { + i.device_type != Some(NetworkInterfaceType::Loopback) + }) + }) .map(|(name, _)| name) .filter(|id| id.as_str() != WIREGUARD_INTERFACE_NAME) .cloned() @@ -206,6 +211,12 @@ impl AsRef for TunnelContext { } } +impl AsRef>> for TunnelContext { + fn as_ref(&self) -> &OpenAuthedContinuations> { + &self.open_authed_continuations + } +} + impl Context for TunnelContext {} impl Deref for TunnelContext { type Target = TunnelContextSeed; diff --git a/core/startos/src/tunnel/db.rs b/core/startos/src/tunnel/db.rs index 91f4d54c9..d11a3501d 100644 --- a/core/startos/src/tunnel/db.rs +++ b/core/startos/src/tunnel/db.rs @@ -1,7 +1,9 @@ use std::collections::BTreeMap; use std::net::SocketAddrV4; use std::path::PathBuf; +use std::time::Duration; +use axum::extract::ws; use clap::Parser; use clap::builder::ValueParserFactory; use imbl::{HashMap, OrdMap}; @@ -20,11 +22,13 @@ use crate::auth::Sessions; use crate::context::CliContext; use crate::db::model::public::NetworkInterfaceInfo; use crate::prelude::*; +use crate::rpc_continuations::{Guid, RpcContinuation}; use crate::sign::AnyVerifyingKey; use crate::tunnel::auth::SignerInfo; use crate::tunnel::context::TunnelContext; use crate::tunnel::web::WebserverInfo; use crate::tunnel::wg::WgServer; +use crate::util::net::WebSocketExt; use crate::util::serde::{HandlerExtSerde, apply_expr, deserialize_from_str, serialize_display}; #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -112,6 +116,12 @@ pub fn db_api() -> ParentHandler { .with_metadata("admin", Value::Bool(true)) .no_cli(), ) + .subcommand( + "subscribe", + from_fn_async(subscribe) + .with_metadata("get_session", Value::Bool(true)) + .no_cli(), + ) .subcommand( "apply", from_fn_async(cli_apply) @@ -260,3 +270,75 @@ pub async fn apply(ctx: TunnelContext, ApplyParams { expr, .. }: ApplyParams) -> .await .result } + +#[derive(Deserialize, Serialize, TS)] +#[serde(rename_all = "camelCase")] +pub struct SubscribeParams { + #[ts(type = "string | null")] + pointer: Option, + #[ts(skip)] + #[serde(rename = "__Auth_session")] + session: Option, +} + +#[derive(Deserialize, Serialize, TS)] +#[serde(rename_all = "camelCase")] +pub struct SubscribeRes { + #[ts(type = "{ id: number; value: unknown }")] + pub dump: Dump, + pub guid: Guid, +} + +pub async fn subscribe( + ctx: TunnelContext, + SubscribeParams { pointer, session }: SubscribeParams, +) -> Result { + let (dump, mut sub) = ctx + .db + .dump_and_sub(pointer.unwrap_or_else(|| ROOT.to_owned())) + .await; + let guid = Guid::new(); + ctx.rpc_continuations + .add( + guid.clone(), + RpcContinuation::ws_authed( + &ctx, + session, + |mut ws| async move { + if let Err(e) = async { + loop { + tokio::select! { + rev = sub.recv() => { + if let Some(rev) = rev { + ws.send(ws::Message::Text( + serde_json::to_string(&rev) + .with_kind(ErrorKind::Serialization)? + .into(), + )) + .await + .with_kind(ErrorKind::Network)?; + } else { + return ws.normal_close("complete").await; + } + } + msg = ws.recv() => { + if msg.transpose().with_kind(ErrorKind::Network)?.is_none() { + return Ok(()) + } + } + } + } + } + .await + { + tracing::error!("Error in db websocket: {e}"); + tracing::debug!("{e:?}"); + } + }, + Duration::from_secs(30), + ), + ) + .await; + + Ok(SubscribeRes { dump, guid }) +}