Feature/start tunnel (#3037)

* fix live-build resolv.conf

* improved debuggability

* wip: start-tunnel

* fixes for trixie and tor

* non-free-firmware on trixie

* wip

* web server WIP

* wip: tls refactor

* FE patchdb, mocks, and most endpoints

* fix editing records and patch mocks

* refactor complete

* finish api

* build and formatter update

* minor change toi viewing addresses and fix build

* fixes

* more providers

* endpoint for getting config

* fix tests

* api fixes

* wip: separate port forward controller into parts

* simplify iptables rules

* bump sdk

* misc fixes

* predict next subnet and ip, use wan ips, and form validation

* refactor: break big components apart and address todos (#3043)

* refactor: break big components apart and address todos

* starttunnel readme, fix pf mocks, fix adding tor domain in startos

---------

Co-authored-by: Matt Hill <mattnine@protonmail.com>

* better tui

* tui tweaks

* fix: address comments

* better regex for subnet

* fixes

* better validation

* handle rpc errors

* build fixes

* fix: address comments (#3044)

* fix: address comments

* fix unread notification mocks

* fix row click for notification

---------

Co-authored-by: Matt Hill <mattnine@protonmail.com>

* fix raspi build

* fix build

* fix build

* fix build

* fix build

* try to fix build

* fix tests

* fix tests

* fix rsync tests

* delete useless effectful test

---------

Co-authored-by: Matt Hill <mattnine@protonmail.com>
Co-authored-by: Alex Inkin <alexander@inkin.ru>
This commit is contained in:
Aiden McClelland
2025-11-07 03:12:05 -07:00
committed by GitHub
parent 1ea525feaa
commit 68f401bfa3
229 changed files with 17255 additions and 10553 deletions

View File

@@ -1,49 +1,61 @@
use std::net::Ipv4Addr;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use clap::Parser;
use imbl_value::InternedString;
use ipnet::Ipv4Net;
use rpc_toolkit::{from_fn_async, Context, Empty, HandlerExt, ParentHandler};
use rpc_toolkit::{Context, Empty, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use crate::context::CliContext;
use crate::prelude::*;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::wg::WgSubnetConfig;
use crate::tunnel::wg::{WgConfig, WgSubnetClients, WgSubnetConfig};
use crate::util::serde::{HandlerExtSerde, display_serializable};
pub fn tunnel_api<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
.subcommand("web", super::web::web_api::<C>())
.subcommand(
"db",
super::db::db_api::<C>()
.with_about("Commands to interact with the db i.e. dump and apply"),
)
.subcommand(
"auth",
super::auth::auth_api::<C>().with_about("Add or remove authorized clients"),
)
.subcommand(
"subnet",
subnet_api::<C>().with_about("Add, remove, or modify subnets"),
)
// .subcommand(
// "port-forward",
// ParentHandler::<C>::new()
// .subcommand(
// "add",
// from_fn_async(add_forward)
// .with_metadata("sync_db", Value::Bool(true))
// .no_display()
// .with_about("Add a new port forward")
// .with_call_remote::<CliContext>(),
// )
// .subcommand(
// "remove",
// from_fn_async(remove_forward)
// .with_metadata("sync_db", Value::Bool(true))
// .no_display()
// .with_about("Remove a port forward")
// .with_call_remote::<CliContext>(),
// ),
// )
.subcommand(
"device",
device_api::<C>().with_about("Add, remove, or list devices in subnets"),
)
.subcommand(
"port-forward",
ParentHandler::<C>::new()
.subcommand(
"add",
from_fn_async(add_forward)
.with_metadata("sync_db", Value::Bool(true))
.no_display()
.with_about("Add a new port forward")
.with_call_remote::<CliContext>(),
)
.subcommand(
"remove",
from_fn_async(remove_forward)
.with_metadata("sync_db", Value::Bool(true))
.no_display()
.with_about("Remove a port forward")
.with_call_remote::<CliContext>(),
),
)
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct SubnetParams {
subnet: Ipv4Net,
}
@@ -68,44 +80,89 @@ pub fn subnet_api<C: Context>() -> ParentHandler<C, SubnetParams> {
.with_about("Remove a subnet")
.with_call_remote::<CliContext>(),
)
// .subcommand(
// "set-default-forward-target",
// from_fn_async(set_default_forward_target)
// .with_metadata("sync_db", Value::Bool(true))
// .no_display()
// .with_about("Set the default target for port forwarding")
// .with_call_remote::<CliContext>(),
// )
// .subcommand(
// "add-device",
// from_fn_async(add_device)
// .with_metadata("sync_db", Value::Bool(true))
// .no_display()
// .with_about("Add a device to a subnet")
// .with_call_remote::<CliContext>(),
// )
// .subcommand(
// "remove-device",
// from_fn_async(remove_device)
// .with_metadata("sync_db", Value::Bool(true))
// .no_display()
// .with_about("Remove a device from a subnet")
// .with_call_remote::<CliContext>(),
// )
}
pub fn device_api<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
.subcommand(
"add",
from_fn_async(add_device)
.with_metadata("sync_db", Value::Bool(true))
.no_display()
.with_about("Add a device to a subnet")
.with_call_remote::<CliContext>(),
)
.subcommand(
"remove",
from_fn_async(remove_device)
.with_metadata("sync_db", Value::Bool(true))
.no_display()
.with_about("Remove a device from a subnet")
.with_call_remote::<CliContext>(),
)
.subcommand(
"list",
from_fn_async(list_devices)
.with_display_serializable()
.with_custom_display_fn(|HandlerArgs { params, .. }, res| {
use prettytable::*;
if let Some(format) = params.format {
return display_serializable(format, res);
}
let mut table = Table::new();
table.add_row(row![bc => "NAME", "IP", "PUBLIC KEY"]);
for (ip, config) in res.clients.0 {
table.add_row(row![config.name, ip, config.key.verifying_key()]);
}
table.print_tty(false)?;
Ok(())
})
.with_about("List devices in a subnet")
.with_call_remote::<CliContext>(),
)
.subcommand(
"show-config",
from_fn_async(show_config)
.with_about("Show the WireGuard configuration for a device")
.with_call_remote::<CliContext>(),
)
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct AddSubnetParams {
name: InternedString,
}
pub async fn add_subnet(
ctx: TunnelContext,
_: Empty,
SubnetParams { subnet }: SubnetParams,
AddSubnetParams { name }: AddSubnetParams,
SubnetParams { mut subnet }: SubnetParams,
) -> Result<(), Error> {
if subnet.prefix_len() > 24 {
return Err(Error::new(
eyre!("invalid subnet"),
ErrorKind::InvalidRequest,
));
}
let addr = subnet
.hosts()
.next()
.ok_or_else(|| Error::new(eyre!("invalid subnet"), ErrorKind::InvalidRequest))?;
subnet = Ipv4Net::new_assert(addr, subnet.prefix_len());
let server = ctx
.db
.mutate(|db| {
let map = db.as_wg_mut().as_subnets_mut();
if !map.contains_key(&subnet)? {
map.insert(&subnet, &WgSubnetConfig::new())?;
}
map.upsert(&subnet, || {
Ok(WgSubnetConfig::new(InternedString::default()))
})?
.as_name_mut()
.ser(&name)?;
db.as_wg().de()
})
.await
@@ -128,3 +185,221 @@ pub async fn remove_subnet(
.result?;
server.sync().await
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct AddDeviceParams {
subnet: Ipv4Net,
name: InternedString,
ip: Option<Ipv4Addr>,
}
pub async fn add_device(
ctx: TunnelContext,
AddDeviceParams { subnet, name, ip }: AddDeviceParams,
) -> Result<(), Error> {
let server = ctx
.db
.mutate(|db| {
db.as_wg_mut()
.as_subnets_mut()
.as_idx_mut(&subnet)
.or_not_found(&subnet)?
.as_clients_mut()
.mutate(|WgSubnetClients(clients)| {
let ip = if let Some(ip) = ip {
ip
} else {
subnet
.hosts()
.find(|ip| !clients.contains_key(ip) && *ip != subnet.addr())
.ok_or_else(|| {
Error::new(
eyre!("no available ips in subnet"),
ErrorKind::InvalidRequest,
)
})?
};
if ip.octets()[3] == 0 || ip.octets()[3] == 255 {
return Err(Error::new(eyre!("invalid ip"), ErrorKind::InvalidRequest));
}
if ip == subnet.addr() {
return Err(Error::new(eyre!("invalid ip"), ErrorKind::InvalidRequest));
}
if !subnet.contains(&ip) {
return Err(Error::new(
eyre!("ip not in subnet"),
ErrorKind::InvalidRequest,
));
}
let client = clients
.entry(ip)
.or_insert_with(|| WgConfig::generate(name.clone()));
client.name = name;
Ok(())
})?;
db.as_wg().de()
})
.await
.result?;
server.sync().await
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct RemoveDeviceParams {
subnet: Ipv4Net,
ip: Ipv4Addr,
}
pub async fn remove_device(
ctx: TunnelContext,
RemoveDeviceParams { subnet, ip }: RemoveDeviceParams,
) -> Result<(), Error> {
let server = ctx
.db
.mutate(|db| {
db.as_wg_mut()
.as_subnets_mut()
.as_idx_mut(&subnet)
.or_not_found(&subnet)?
.as_clients_mut()
.remove(&ip)?
.or_not_found(&ip)?;
db.as_wg().de()
})
.await
.result?;
server.sync().await
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct ListDevicesParams {
subnet: Ipv4Net,
}
pub async fn list_devices(
ctx: TunnelContext,
ListDevicesParams { subnet }: ListDevicesParams,
) -> Result<WgSubnetConfig, Error> {
ctx.db
.peek()
.await
.as_wg()
.as_subnets()
.as_idx(&subnet)
.or_not_found(&subnet)?
.de()
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct ShowConfigParams {
subnet: Ipv4Net,
ip: Ipv4Addr,
wan_addr: Option<IpAddr>,
#[serde(rename = "__ConnectInfo_local_addr")]
#[arg(skip)]
local_addr: Option<SocketAddr>,
}
pub async fn show_config(
ctx: TunnelContext,
ShowConfigParams {
subnet,
ip,
wan_addr,
local_addr,
}: ShowConfigParams,
) -> Result<String, Error> {
let peek = ctx.db.peek().await;
let wg = peek.as_wg();
let client = wg
.as_subnets()
.as_idx(&subnet)
.or_not_found(&subnet)?
.as_clients()
.as_idx(&ip)
.or_not_found(&ip)?
.de()?;
let wan_addr = if let Some(wan_addr) = wan_addr.or(local_addr.map(|a| a.ip())).filter(|ip| {
!ip.is_loopback()
&& !match ip {
IpAddr::V4(ipv4) => ipv4.is_private() || ipv4.is_link_local(),
IpAddr::V6(ipv6) => ipv6.is_unique_local() || ipv6.is_unicast_link_local(),
}
}) {
wan_addr
} else if let Some(webserver) = peek.as_webserver().as_listen().de()? {
webserver.ip()
} else {
ctx.net_iface
.peek(|i| {
i.iter().find_map(|(_, info)| {
info.ip_info
.as_ref()
.filter(|_| info.public())
.iter()
.find_map(|info| info.subnets.iter().next())
.copied()
})
})
.or_not_found("a public IP address")?
.addr()
};
Ok(client
.client_config(
ip,
wg.as_key().de()?.verifying_key(),
(wan_addr, wg.as_port().de()?).into(),
)
.to_string())
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct AddPortForwardParams {
source: SocketAddrV4,
target: SocketAddrV4,
}
pub async fn add_forward(
ctx: TunnelContext,
AddPortForwardParams { source, target }: AddPortForwardParams,
) -> Result<(), Error> {
let rc = ctx.forward.add_forward(source, target).await?;
ctx.active_forwards.mutate(|m| {
m.insert(source, rc);
});
ctx.db
.mutate(|db| db.as_port_forwards_mut().insert(&source, &target))
.await
.result?;
Ok(())
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct RemovePortForwardParams {
source: SocketAddrV4,
}
pub async fn remove_forward(
ctx: TunnelContext,
RemovePortForwardParams { source, .. }: RemovePortForwardParams,
) -> Result<(), Error> {
ctx.db
.mutate(|db| db.as_port_forwards_mut().remove(&source))
.await
.result?;
if let Some(rc) = ctx.active_forwards.mutate(|m| m.remove(&source)) {
drop(rc);
ctx.forward.gc().await?;
}
Ok(())
}

View File

@@ -0,0 +1,323 @@
use clap::Parser;
use imbl::HashMap;
use imbl_value::InternedString;
use itertools::Itertools;
use patch_db::HasModel;
use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use crate::auth::{Sessions, check_password};
use crate::context::CliContext;
use crate::middleware::auth::AuthContext;
use crate::middleware::signature::SignatureAuthContext;
use crate::prelude::*;
use crate::rpc_continuations::OpenAuthedContinuations;
use crate::sign::AnyVerifyingKey;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::db::TunnelDatabase;
use crate::util::serde::{HandlerExtSerde, display_serializable};
use crate::util::sync::SyncMutex;
impl SignatureAuthContext for TunnelContext {
type Database = TunnelDatabase;
type AdditionalMetadata = ();
type CheckPubkeyRes = ();
fn db(&self) -> &TypedPatchDb<Self::Database> {
&self.db
}
async fn sig_context(
&self,
) -> impl IntoIterator<Item = Result<impl AsRef<str> + Send, Error>> + Send {
let peek = self.db().peek().await;
peek.as_webserver()
.as_listen()
.de()
.map(|a| a.as_ref().map(InternedString::from_display))
.transpose()
.into_iter()
.chain(
std::iter::once_with(move || {
peek.as_webserver()
.as_certificate()
.de()
.ok()
.flatten()
.and_then(|cert_data| cert_data.cert.0.first().cloned())
.and_then(|cert| cert.subject_alt_names())
.into_iter()
.flatten()
.filter_map(|san| {
san.dnsname().map(InternedString::from).or_else(|| {
san.ipaddress().and_then(|ip_bytes| {
let ip: std::net::IpAddr = match ip_bytes.len() {
4 => std::net::IpAddr::V4(std::net::Ipv4Addr::from(
<[u8; 4]>::try_from(ip_bytes).ok()?,
)),
16 => std::net::IpAddr::V6(std::net::Ipv6Addr::from(
<[u8; 16]>::try_from(ip_bytes).ok()?,
)),
_ => return None,
};
Some(InternedString::from_display(&ip))
})
})
})
.map(Ok)
.collect::<Vec<_>>()
})
.flatten(),
)
}
fn check_pubkey(
db: &Model<Self::Database>,
pubkey: Option<&crate::sign::AnyVerifyingKey>,
_: Self::AdditionalMetadata,
) -> Result<Self::CheckPubkeyRes, Error> {
if let Some(pubkey) = pubkey {
if db.as_auth_pubkeys().de()?.contains_key(pubkey) {
return Ok(());
}
}
Err(Error::new(
eyre!("Key is not authorized"),
ErrorKind::IncorrectPassword,
))
}
async fn post_auth_hook(
&self,
_: Self::CheckPubkeyRes,
_: &rpc_toolkit::RpcRequest,
) -> Result<(), Error> {
Ok(())
}
}
impl AuthContext for TunnelContext {
const LOCAL_AUTH_COOKIE_PATH: &str = "/run/start-tunnel/rpc.authcookie";
const LOCAL_AUTH_COOKIE_OWNERSHIP: &str = "root:root";
fn access_sessions(db: &mut Model<Self::Database>) -> &mut Model<crate::auth::Sessions> {
db.as_sessions_mut()
}
fn ephemeral_sessions(&self) -> &SyncMutex<Sessions> {
&self.ephemeral_sessions
}
fn open_authed_continuations(&self) -> &OpenAuthedContinuations<Option<InternedString>> {
&self.open_authed_continuations
}
fn check_password(db: &Model<Self::Database>, password: &str) -> Result<(), Error> {
check_password(&db.as_password().de()?.unwrap_or_default(), password)
}
}
#[derive(Clone, Debug, Deserialize, Serialize, HasModel, TS, Parser)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
#[ts(export)]
pub struct SignerInfo {
pub name: InternedString,
}
pub fn auth_api<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
.subcommand(
"login",
from_fn_async(crate::auth::login_impl::<TunnelContext>)
.with_metadata("login", Value::Bool(true))
.no_cli(),
)
.subcommand(
"logout",
from_fn_async(crate::auth::logout::<TunnelContext>)
.with_metadata("get_session", Value::Bool(true))
.no_display()
.with_about("Log out of current auth session")
.with_call_remote::<CliContext>(),
)
.subcommand("set-password", from_fn_async(set_password_rpc).no_cli())
.subcommand(
"set-password",
from_fn_async(set_password_cli)
.with_about("Set user interface password")
.no_display(),
)
.subcommand(
"reset-password",
from_fn_async(reset_password)
.with_about("Reset user interface password")
.no_display(),
)
.subcommand(
"key",
ParentHandler::<C>::new()
.subcommand(
"add",
from_fn_async(add_key)
.with_metadata("sync_db", Value::Bool(true))
.no_display()
.with_about("Add a new authorized key")
.with_call_remote::<CliContext>(),
)
.subcommand(
"remove",
from_fn_async(remove_key)
.with_metadata("sync_db", Value::Bool(true))
.no_display()
.with_about("Remove an authorized key")
.with_call_remote::<CliContext>(),
)
.subcommand(
"list",
from_fn_async(list_keys)
.with_metadata("sync_db", Value::Bool(true))
.with_display_serializable()
.with_custom_display_fn(|HandlerArgs { params, .. }, res| {
use prettytable::*;
if let Some(format) = params.format {
return display_serializable(format, res);
}
let mut table = Table::new();
table.add_row(row![bc => "NAME", "KEY"]);
for (key, info) in res {
table.add_row(row![info.name, key]);
}
table.print_tty(false)?;
Ok(())
})
.with_about("List authorized keys")
.with_call_remote::<CliContext>(),
),
)
}
#[derive(Debug, Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct AddKeyParams {
pub name: InternedString,
pub key: AnyVerifyingKey,
}
pub async fn add_key(
ctx: TunnelContext,
AddKeyParams { name, key }: AddKeyParams,
) -> Result<(), Error> {
ctx.db
.mutate(|db| {
db.as_auth_pubkeys_mut().mutate(|auth_pubkeys| {
auth_pubkeys.insert(key, SignerInfo { name });
Ok(())
})
})
.await
.result
}
#[derive(Debug, Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct RemoveKeyParams {
pub key: AnyVerifyingKey,
}
pub async fn remove_key(
ctx: TunnelContext,
RemoveKeyParams { key }: RemoveKeyParams,
) -> Result<(), Error> {
ctx.db
.mutate(|db| {
db.as_auth_pubkeys_mut()
.mutate(|auth_pubkeys| Ok(auth_pubkeys.remove(&key)))
})
.await
.result?;
Ok(())
}
pub async fn list_keys(ctx: TunnelContext) -> Result<HashMap<AnyVerifyingKey, SignerInfo>, Error> {
ctx.db.peek().await.into_auth_pubkeys().de()
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SetPasswordParams {
pub password: String,
}
pub async fn set_password_rpc(
ctx: TunnelContext,
SetPasswordParams { password }: SetPasswordParams,
) -> Result<(), Error> {
let pwhash = argon2::hash_encoded(
password.as_bytes(),
&rand::random::<[u8; 16]>(),
&argon2::Config::rfc9106_low_mem(),
)
.with_kind(ErrorKind::PasswordHashGeneration)?;
ctx.db
.mutate(|db| db.as_password_mut().ser(&Some(pwhash)))
.await
.result?;
Ok(())
}
pub async fn set_password_cli(
HandlerArgs {
context,
parent_method,
method,
..
}: HandlerArgs<CliContext>,
) -> Result<(), Error> {
let password = rpassword::prompt_password("New Password: ")?;
let confirm = rpassword::prompt_password("Confirm Password: ")?;
if password != confirm {
return Err(Error::new(
eyre!("Passwords do not match"),
ErrorKind::InvalidRequest,
));
}
context
.call_remote::<TunnelContext>(
&parent_method.iter().chain(method.iter()).join("."),
to_value(&SetPasswordParams { password })?,
)
.await?;
println!("Password set successfully");
Ok(())
}
pub async fn reset_password(
HandlerArgs {
context,
parent_method,
method,
..
}: HandlerArgs<CliContext>,
) -> Result<(), Error> {
println!("Generating a random password...");
let params = SetPasswordParams {
password: base32::encode(
base32::Alphabet::Rfc4648Lower { padding: false },
&rand::random::<[u8; 16]>(),
),
};
context
.call_remote::<TunnelContext>(
&parent_method.iter().chain(method.iter()).join("."),
to_value(&params)?,
)
.await?;
println!("Your new password is:");
println!("{}", params.password);
Ok(())
}

View File

@@ -1,3 +1,5 @@
# StartTunnel config for {name}
[Interface]
Address = {addr}/24
PrivateKey = {privkey}
@@ -5,6 +7,6 @@ PrivateKey = {privkey}
[Peer]
PublicKey = {server_pubkey}
PresharedKey = {psk}
AllowedIPs = 0.0.0.0/0, ::/0
AllowedIPs = 0.0.0.0/0,::/0
Endpoint = {server_addr}
PersistentKeepalive = 25

View File

@@ -1,31 +1,44 @@
use std::collections::BTreeSet;
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::collections::BTreeMap;
use std::net::{IpAddr, SocketAddr, SocketAddrV4};
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use clap::Parser;
use cookie::{Cookie, Expiration, SameSite};
use http::HeaderMap;
use imbl::OrdMap;
use imbl_value::InternedString;
use include_dir::Dir;
use models::GatewayId;
use patch_db::PatchDb;
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{CallRemote, Context, Empty};
use rpc_toolkit::{CallRemote, Context, Empty, ParentHandler};
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use tokio::sync::broadcast::Sender;
use tracing::instrument;
use url::Url;
use crate::auth::{Sessions, check_password};
use crate::context::CliContext;
use crate::auth::Sessions;
use crate::context::config::ContextConfig;
use crate::middleware::auth::AuthContext;
use crate::middleware::signature::SignatureAuthContext;
use crate::context::{CliContext, RpcContext};
use crate::db::model::public::{NetworkInterfaceInfo, NetworkInterfaceType};
use crate::else_empty_dir;
use crate::middleware::auth::{Auth, AuthContext};
use crate::middleware::cors::Cors;
use crate::net::forward::PortForwardController;
use crate::net::gateway::NetworkInterfaceWatcher;
use crate::net::static_server::UiContext;
use crate::prelude::*;
use crate::rpc_continuations::{OpenAuthedContinuations, RpcContinuations};
use crate::tunnel::TUNNEL_DEFAULT_PORT;
use crate::tunnel::TUNNEL_DEFAULT_LISTEN;
use crate::tunnel::api::tunnel_api;
use crate::tunnel::db::TunnelDatabase;
use crate::util::sync::SyncMutex;
use crate::tunnel::wg::WIREGUARD_INTERFACE_NAME;
use crate::util::Invoke;
use crate::util::collections::OrdMapIterMut;
use crate::util::io::read_file_to_string;
use crate::util::sync::{SyncMutex, Watch};
#[derive(Debug, Clone, Default, Deserialize, Serialize, Parser)]
#[serde(rename_all = "kebab-case")]
@@ -59,14 +72,14 @@ impl TunnelConfig {
pub struct TunnelContextSeed {
pub listen: SocketAddr,
pub addrs: BTreeSet<IpAddr>,
pub db: TypedPatchDb<TunnelDatabase>,
pub datadir: PathBuf,
pub rpc_continuations: RpcContinuations,
pub open_authed_continuations: OpenAuthedContinuations<Option<InternedString>>,
pub ephemeral_sessions: SyncMutex<Sessions>,
pub net_iface: NetworkInterfaceWatcher,
pub net_iface: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
pub forward: PortForwardController,
pub active_forwards: SyncMutex<BTreeMap<SocketAddrV4, Arc<()>>>,
pub shutdown: Sender<()>,
}
@@ -75,6 +88,7 @@ pub struct TunnelContext(Arc<TunnelContextSeed>);
impl TunnelContext {
#[instrument(skip_all)]
pub async fn init(config: &TunnelConfig) -> Result<Self, Error> {
Self::init_auth_cookie().await?;
let (shutdown, _) = tokio::sync::broadcast::channel(1);
let datadir = config
.datadir
@@ -90,19 +104,83 @@ impl TunnelContext {
|| async { Ok(Default::default()) },
)
.await?;
let listen = config.tunnel_listen.unwrap_or(SocketAddr::new(
Ipv6Addr::UNSPECIFIED.into(),
TUNNEL_DEFAULT_PORT,
));
let net_iface = NetworkInterfaceWatcher::new(async { OrdMap::new() }, []);
let forward = PortForwardController::new(net_iface.subscribe());
let listen = config.tunnel_listen.unwrap_or(TUNNEL_DEFAULT_LISTEN);
let ip_info = crate::net::utils::load_ip_info().await?;
let net_iface = db
.mutate(|db| {
db.as_gateways_mut().mutate(|g| {
for (_, v) in OrdMapIterMut::from(&mut *g) {
v.ip_info = None;
}
for (id, info) in ip_info {
if id.as_str() != WIREGUARD_INTERFACE_NAME {
g.entry(id).or_default().ip_info = Some(Arc::new(info));
}
}
Ok(g.clone())
})
})
.await
.result?;
let net_iface = Watch::new(net_iface);
let forward = PortForwardController::new();
Command::new("sysctl")
.arg("-w")
.arg("net.ipv4.ip_forward=1")
.invoke(ErrorKind::Network)
.await?;
for iface in net_iface.peek(|i| {
i.iter()
.filter(|(_, info)| {
info.ip_info.as_ref().map_or(false, |i| {
i.device_type != Some(NetworkInterfaceType::Loopback)
})
})
.map(|(name, _)| name)
.filter(|id| id.as_str() != WIREGUARD_INTERFACE_NAME)
.cloned()
.collect::<Vec<_>>()
}) {
if Command::new("iptables")
.arg("-t")
.arg("nat")
.arg("-C")
.arg("POSTROUTING")
.arg("-o")
.arg(iface.as_str())
.arg("-j")
.arg("MASQUERADE")
.invoke(ErrorKind::Network)
.await
.is_err()
{
tracing::info!("Adding masquerade rule for interface {}", iface);
Command::new("iptables")
.arg("-t")
.arg("nat")
.arg("-A")
.arg("POSTROUTING")
.arg("-o")
.arg(iface.as_str())
.arg("-j")
.arg("MASQUERADE")
.invoke(ErrorKind::Network)
.await
.log_err();
}
}
let peek = db.peek().await;
peek.as_wg().de()?.sync().await?;
let mut active_forwards = BTreeMap::new();
for (from, to) in peek.as_port_forwards().de()?.0 {
active_forwards.insert(from, forward.add_forward(from, to).await?);
}
Ok(Self(Arc::new(TunnelContextSeed {
listen,
addrs: crate::net::utils::all_socket_addrs_for(listen.port())
.await?
.into_iter()
.map(|(_, a)| a.ip())
.collect(),
db,
datadir,
rpc_continuations: RpcContinuations::new(),
@@ -110,6 +188,7 @@ impl TunnelContext {
ephemeral_sessions: SyncMutex::new(Sessions::new()),
net_iface,
forward,
active_forwards: SyncMutex::new(active_forwards),
shutdown,
})))
}
@@ -120,6 +199,12 @@ impl AsRef<RpcContinuations> for TunnelContext {
}
}
impl AsRef<OpenAuthedContinuations<Option<InternedString>>> for TunnelContext {
fn as_ref(&self) -> &OpenAuthedContinuations<Option<InternedString>> {
&self.open_authed_continuations
}
}
impl Context for TunnelContext {}
impl Deref for TunnelContext {
type Target = TunnelContextSeed;
@@ -133,66 +218,6 @@ pub struct TunnelAddrParams {
pub tunnel: IpAddr,
}
impl SignatureAuthContext for TunnelContext {
type Database = TunnelDatabase;
type AdditionalMetadata = ();
type CheckPubkeyRes = ();
fn db(&self) -> &TypedPatchDb<Self::Database> {
&self.db
}
async fn sig_context(
&self,
) -> impl IntoIterator<Item = Result<impl AsRef<str> + Send, Error>> + Send {
self.addrs
.iter()
.filter(|a| !match a {
IpAddr::V4(a) => a.is_loopback() || a.is_unspecified(),
IpAddr::V6(a) => a.is_loopback() || a.is_unspecified(),
})
.map(|a| InternedString::from_display(&a))
.map(Ok)
}
fn check_pubkey(
db: &Model<Self::Database>,
pubkey: Option<&crate::sign::AnyVerifyingKey>,
_: Self::AdditionalMetadata,
) -> Result<Self::CheckPubkeyRes, Error> {
if let Some(pubkey) = pubkey {
if db.as_auth_pubkeys().de()?.contains(pubkey) {
return Ok(());
}
}
Err(Error::new(
eyre!("Developer Key is not authorized"),
ErrorKind::IncorrectPassword,
))
}
async fn post_auth_hook(
&self,
_: Self::CheckPubkeyRes,
_: &rpc_toolkit::RpcRequest,
) -> Result<(), Error> {
Ok(())
}
}
impl AuthContext for TunnelContext {
const LOCAL_AUTH_COOKIE_PATH: &str = "/run/start-tunnel/rpc.authcookie";
const LOCAL_AUTH_COOKIE_OWNERSHIP: &str = "root:root";
fn access_sessions(db: &mut Model<Self::Database>) -> &mut Model<crate::auth::Sessions> {
db.as_sessions_mut()
}
fn ephemeral_sessions(&self) -> &SyncMutex<Sessions> {
&self.ephemeral_sessions
}
fn open_authed_continuations(&self) -> &OpenAuthedContinuations<Option<InternedString>> {
&self.open_authed_continuations
}
fn check_password(db: &Model<Self::Database>, password: &str) -> Result<(), Error> {
check_password(&db.as_password().de()?, password)
}
}
impl CallRemote<TunnelContext> for CliContext {
async fn call_remote(
&self,
@@ -200,25 +225,97 @@ impl CallRemote<TunnelContext> for CliContext {
params: Value,
_: Empty,
) -> Result<Value, RpcError> {
let tunnel_addr = if let Some(addr) = self.tunnel_addr {
addr
let (tunnel_addr, addr_from_config) = if let Some(addr) = self.tunnel_addr {
(addr, true)
} else if let Some(addr) = self.tunnel_listen {
addr
(addr, true)
} else {
(TUNNEL_DEFAULT_LISTEN, false)
};
let local =
if let Ok(local) = read_file_to_string(TunnelContext::LOCAL_AUTH_COOKIE_PATH).await {
self.cookie_store
.lock()
.unwrap()
.insert_raw(
&Cookie::build(("local", local))
.domain(&tunnel_addr.ip().to_string())
.expires(Expiration::Session)
.same_site(SameSite::Strict)
.build(),
&format!("http://{tunnel_addr}").parse()?,
)
.with_kind(crate::ErrorKind::Network)?;
true
} else {
false
};
let (url, sig_ctx) = if local && tunnel_addr.ip().is_loopback() {
(format!("http://{tunnel_addr}/rpc/v0").parse()?, None)
} else if addr_from_config {
(
format!("https://{tunnel_addr}/rpc/v0").parse()?,
Some(InternedString::from_display(&tunnel_addr.ip())),
)
} else {
return Err(Error::new(eyre!("`--tunnel` required"), ErrorKind::InvalidRequest).into());
};
let sig_addr = self.tunnel_listen.unwrap_or(tunnel_addr);
let url = format!("https://{tunnel_addr}").parse()?;
method = method.strip_prefix("tunnel.").unwrap_or(method);
crate::middleware::signature::call_remote(
self,
url,
&InternedString::from_display(&sig_addr.ip()),
HeaderMap::new(),
sig_ctx.as_deref(),
method,
params,
)
.await
}
}
#[derive(Debug, Deserialize, Serialize, Parser)]
pub struct TunnelUrlParams {
pub tunnel: Url,
}
impl CallRemote<TunnelContext, TunnelUrlParams> for RpcContext {
async fn call_remote(
&self,
mut method: &str,
params: Value,
TunnelUrlParams { tunnel }: TunnelUrlParams,
) -> Result<Value, RpcError> {
let url = tunnel.join("rpc/v0")?;
method = method.strip_prefix("tunnel.").unwrap_or(method);
let sig_ctx = url.host_str().map(InternedString::from_display);
crate::middleware::signature::call_remote(
self,
url,
HeaderMap::new(),
sig_ctx.as_deref(),
method,
params,
)
.await
}
}
impl UiContext for TunnelContext {
const UI_DIR: &'static include_dir::Dir<'static> = &else_empty_dir!(
feature = "tunnel" =>
include_dir::include_dir!("$CARGO_MANIFEST_DIR/../../web/dist/static/start-tunnel")
);
fn api() -> ParentHandler<Self> {
tracing::info!("loading tunnel api...");
tunnel_api()
}
fn middleware(server: rpc_toolkit::Server<Self>) -> rpc_toolkit::HttpServer<Self> {
server.middleware(Cors::new()).middleware(Auth::new())
}
}

View File

@@ -1,11 +1,14 @@
use std::collections::{BTreeMap, HashSet};
use std::net::{Ipv4Addr, SocketAddrV4};
use std::collections::BTreeMap;
use std::net::SocketAddrV4;
use std::path::PathBuf;
use std::time::Duration;
use axum::extract::ws;
use clap::Parser;
use imbl::{HashMap, OrdMap};
use imbl_value::InternedString;
use ipnet::Ipv4Net;
use itertools::Itertools;
use models::GatewayId;
use patch_db::Dump;
use patch_db::json_ptr::{JsonPointer, ROOT};
use rpc_toolkit::yajrc::RpcError;
@@ -16,21 +19,48 @@ use ts_rs::TS;
use crate::auth::Sessions;
use crate::context::CliContext;
use crate::db::model::public::NetworkInterfaceInfo;
use crate::prelude::*;
use crate::rpc_continuations::{Guid, RpcContinuation};
use crate::sign::AnyVerifyingKey;
use crate::tunnel::auth::SignerInfo;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::web::WebserverInfo;
use crate::tunnel::wg::WgServer;
use crate::util::net::WebSocketExt;
use crate::util::serde::{HandlerExtSerde, apply_expr};
#[derive(Default, Deserialize, Serialize, HasModel)]
#[derive(Default, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
pub struct TunnelDatabase {
pub webserver: WebserverInfo,
pub sessions: Sessions,
pub password: String,
pub auth_pubkeys: HashSet<AnyVerifyingKey>,
pub password: Option<String>,
#[ts(as = "std::collections::HashMap::<AnyVerifyingKey, SignerInfo>")]
pub auth_pubkeys: HashMap<AnyVerifyingKey, SignerInfo>,
#[ts(as = "std::collections::BTreeMap::<AnyVerifyingKey, SignerInfo>")]
pub gateways: OrdMap<GatewayId, NetworkInterfaceInfo>,
pub wg: WgServer,
pub port_forwards: BTreeMap<SocketAddrV4, SocketAddrV4>,
pub port_forwards: PortForwards,
}
#[test]
fn export_bindings_tunnel_db() {
TunnelDatabase::export_all_to("bindings/tunnel").unwrap();
}
#[derive(Clone, Debug, Default, Deserialize, Serialize, TS)]
pub struct PortForwards(pub BTreeMap<SocketAddrV4, SocketAddrV4>);
impl Map for PortForwards {
type Key = SocketAddrV4;
type Value = SocketAddrV4;
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
Self::key_string(key)
}
fn key_string(key: &Self::Key) -> Result<InternedString, Error> {
Ok(InternedString::from_display(key))
}
}
pub fn db_api<C: Context>() -> ParentHandler<C> {
@@ -47,6 +77,12 @@ pub fn db_api<C: Context>() -> ParentHandler<C> {
.with_metadata("admin", Value::Bool(true))
.no_cli(),
)
.subcommand(
"subscribe",
from_fn_async(subscribe)
.with_metadata("get_session", Value::Bool(true))
.no_cli(),
)
.subcommand(
"apply",
from_fn_async(cli_apply)
@@ -195,3 +231,75 @@ pub async fn apply(ctx: TunnelContext, ApplyParams { expr, .. }: ApplyParams) ->
.await
.result
}
#[derive(Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct SubscribeParams {
#[ts(type = "string | null")]
pointer: Option<JsonPointer>,
#[ts(skip)]
#[serde(rename = "__Auth_session")]
session: Option<InternedString>,
}
#[derive(Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct SubscribeRes {
#[ts(type = "{ id: number; value: unknown }")]
pub dump: Dump,
pub guid: Guid,
}
pub async fn subscribe(
ctx: TunnelContext,
SubscribeParams { pointer, session }: SubscribeParams,
) -> Result<SubscribeRes, Error> {
let (dump, mut sub) = ctx
.db
.dump_and_sub(pointer.unwrap_or_else(|| ROOT.to_owned()))
.await;
let guid = Guid::new();
ctx.rpc_continuations
.add(
guid.clone(),
RpcContinuation::ws_authed(
&ctx,
session,
|mut ws| async move {
if let Err(e) = async {
loop {
tokio::select! {
rev = sub.recv() => {
if let Some(rev) = rev {
ws.send(ws::Message::Text(
serde_json::to_string(&rev)
.with_kind(ErrorKind::Serialization)?
.into(),
))
.await
.with_kind(ErrorKind::Network)?;
} else {
return ws.normal_close("complete").await;
}
}
msg = ws.recv() => {
if msg.transpose().with_kind(ErrorKind::Network)?.is_none() {
return Ok(())
}
}
}
}
}
.await
{
tracing::error!("Error in db websocket: {e}");
tracing::debug!("{e:?}");
}
},
Duration::from_secs(30),
),
)
.await;
Ok(SubscribeRes { dump, guid })
}

View File

@@ -1 +0,0 @@
use crate::prelude::*;

View File

@@ -1,82 +1,23 @@
use axum::Router;
use futures::future::ready;
use rpc_toolkit::{Context, HandlerExt, ParentHandler, Server, from_fn_async};
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use crate::context::CliContext;
use crate::middleware::auth::Auth;
use crate::middleware::cors::Cors;
use crate::net::static_server::{bad_request, not_found, server_error};
use crate::net::web_server::{Accept, WebServer};
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use axum::Router;
use crate::net::static_server::ui_router;
use crate::tunnel::context::TunnelContext;
pub mod api;
pub mod auth;
pub mod context;
pub mod db;
pub mod forward;
pub mod web;
pub mod wg;
pub const TUNNEL_DEFAULT_PORT: u16 = 5960;
pub const TUNNEL_DEFAULT_LISTEN: SocketAddr = SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::new(127, 0, 59, 60),
TUNNEL_DEFAULT_PORT,
));
pub fn tunnel_router(ctx: TunnelContext) -> Router {
use axum::extract as x;
use axum::routing::{any, get};
Router::new()
.route("/rpc/{*path}", {
let ctx = ctx.clone();
any(
Server::new(move || ready(Ok(ctx.clone())), api::tunnel_api())
.middleware(Cors::new())
.middleware(Auth::new())
)
})
.route(
"/ws/rpc/{*path}",
get({
let ctx = ctx.clone();
move |x::Path(path): x::Path<String>,
ws: axum::extract::ws::WebSocketUpgrade| async move {
match Guid::from(&path) {
None => {
tracing::debug!("No Guid Path");
bad_request()
}
Some(guid) => match ctx.rpc_continuations.get_ws_handler(&guid).await {
Some(cont) => ws.on_upgrade(cont),
_ => not_found(),
},
}
}
}),
)
.route(
"/rest/rpc/{*path}",
any({
let ctx = ctx.clone();
move |request: x::Request| async move {
let path = request
.uri()
.path()
.strip_prefix("/rest/rpc/")
.unwrap_or_default();
match Guid::from(&path) {
None => {
tracing::debug!("No Guid Path");
bad_request()
}
Some(guid) => match ctx.rpc_continuations.get_rest_handler(&guid).await {
None => not_found(),
Some(cont) => cont(request).await.unwrap_or_else(server_error),
},
}
}
}),
)
}
impl<A: Accept + Send + Sync + 'static> WebServer<A> {
pub fn serve_tunnel(&mut self, ctx: TunnelContext) {
self.serve_router(tunnel_router(ctx))
}
ui_router(ctx)
}

View File

@@ -0,0 +1,688 @@
use std::collections::VecDeque;
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use clap::Parser;
use hickory_client::proto::rr::rdata::cert;
use imbl_value::{InternedString, json};
use itertools::Itertools;
use openssl::pkey::{PKey, Private};
use openssl::x509::X509;
use rpc_toolkit::{
Context, Empty, HandlerArgs, HandlerExt, ParentHandler, from_fn_async, from_fn_async_local,
};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio_rustls::rustls::ServerConfig;
use tokio_rustls::rustls::crypto::CryptoProvider;
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio_rustls::rustls::server::ClientHello;
use ts_rs::TS;
use crate::context::CliContext;
use crate::net::ssl::SANInfo;
use crate::net::tls::TlsHandler;
use crate::net::web_server::Accept;
use crate::prelude::*;
use crate::tunnel::auth::SetPasswordParams;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::db::TunnelDatabase;
use crate::util::serde::{HandlerExtSerde, Pem, display_serializable};
use crate::util::tui::{choose, choose_custom_display, parse_as, prompt, prompt_multiline};
#[derive(Debug, Default, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
pub struct WebserverInfo {
pub enabled: bool,
pub listen: Option<SocketAddr>,
pub certificate: Option<TunnelCertData>,
}
#[derive(Debug, Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct TunnelCertData {
pub key: Pem<PKey<Private>>,
pub cert: Pem<Vec<X509>>,
}
#[derive(Clone)]
pub struct TunnelCertHandler {
pub db: TypedPatchDb<TunnelDatabase>,
pub crypto_provider: Arc<CryptoProvider>,
}
impl<'a, A> TlsHandler<'a, A> for TunnelCertHandler
where
A: Accept + 'a,
<A as Accept>::Metadata: Send + Sync,
{
async fn get_config(
&'a mut self,
_: &'a ClientHello<'a>,
_: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig> {
let cert_info = self
.db
.peek()
.await
.as_webserver()
.as_certificate()
.de()
.log_err()??;
let cert_chain: Vec<_> = cert_info
.cert
.0
.iter()
.map(|c| Ok::<_, Error>(CertificateDer::from(c.to_der()?)))
.collect::<Result<_, _>>()
.log_err()?;
let cert_key = cert_info.key.0.private_key_to_pkcs8().log_err()?;
Some(
ServerConfig::builder_with_provider(self.crypto_provider.clone())
.with_safe_default_protocol_versions()
.log_err()?
.with_no_client_auth()
.with_single_cert(
cert_chain,
PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert_key)),
)
.log_err()?,
)
}
}
pub fn web_api<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
.subcommand(
"init",
from_fn_async_local(init_web)
.no_display()
.with_about("Initialize the webserver"),
)
.subcommand(
"set-listen",
from_fn_async(set_listen)
.no_display()
.with_about("Set the listen address for the webserver")
.with_call_remote::<CliContext>(),
)
.subcommand(
"get-listen",
from_fn_async(get_listen)
.with_display_serializable()
.with_about("Get the listen address for the webserver")
.with_call_remote::<CliContext>(),
)
.subcommand(
"get-available-ips",
from_fn_async(get_available_ips)
.with_display_serializable()
.with_about("Get available IP addresses to bind to")
.with_call_remote::<CliContext>(),
)
.subcommand(
"import-certificate",
from_fn_async(import_certificate_rpc).no_cli(),
)
.subcommand(
"import-certificate",
from_fn_async_local(import_certificate_cli)
.no_display()
.with_about("Import a certificate to use for the webserver"),
)
.subcommand(
"generate-certificate",
from_fn_async(generate_certificate)
.with_about("Generate a self signed certificaet to use for the webserver")
.with_call_remote::<CliContext>(),
)
.subcommand(
"get-certificate",
from_fn_async(get_certificate)
.with_display_serializable()
.with_custom_display_fn(|HandlerArgs { params, .. }, res| {
if let Some(format) = params.format {
return display_serializable(format, res);
}
if let Some(res) = res {
println!("{res}");
}
Ok(())
})
.with_about("Get the certificate for the webserver")
.with_call_remote::<CliContext>(),
)
.subcommand(
"enable",
from_fn_async(enable_web)
.with_about("Enable the webserver")
.no_display()
.with_call_remote::<CliContext>(),
)
.subcommand(
"disable",
from_fn_async(disable_web)
.no_display()
.with_about("Disable the webserver")
.with_call_remote::<CliContext>(),
)
.subcommand(
"reset",
from_fn_async(reset_web)
.no_display()
.with_about("Reset the webserver")
.with_call_remote::<CliContext>(),
)
}
pub async fn import_certificate_rpc(
ctx: TunnelContext,
cert_data: TunnelCertData,
) -> Result<(), Error> {
ctx.db
.mutate(|db| {
db.as_webserver_mut()
.as_certificate_mut()
.ser(&Some(cert_data))
})
.await
.result?;
Ok(())
}
pub async fn import_certificate_cli(
HandlerArgs {
context,
parent_method,
method,
..
}: HandlerArgs<CliContext>,
) -> Result<(), Error> {
let mut key_string = String::new();
let key: Pem<PKey<Private>> =
prompt_multiline("Please paste in your PEM encoded private key: ", |line| {
key_string.push_str(&line);
key_string.push_str("\n");
if line.trim().starts_with("-----END") {
return key_string.parse().map(Some).map_err(|e| {
key_string.truncate(0);
e
});
}
Ok(None)
})
.await?;
let mut chain = Vec::<X509>::new();
let mut cert_string = String::new();
prompt_multiline(
concat!(
"Please paste in your PEM encoded certificate",
" (or certificate chain):"
),
|line| {
cert_string.push_str(&line);
cert_string.push_str("\n");
if line.trim().starts_with("-----END") {
let cert = cert_string.parse::<Pem<X509>>();
cert_string.truncate(0);
let cert = cert?;
let pubkey = cert.0.public_key()?;
if chain.is_empty() {
if !key.public_eq(&pubkey) {
return Err(Error::new(
eyre!("Certificate does not match key!"),
ErrorKind::InvalidSignature,
));
}
}
if let Some(prev) = chain.last() {
if !prev.verify(&pubkey)? {
return Err(Error::new(
eyre!(concat!(
"Invalid Fullchain: ",
"Previous cert was not signed by this certificate's key"
)),
ErrorKind::InvalidSignature,
));
}
}
let is_root = cert.0.verify(&pubkey)?;
chain.push(cert.0);
if is_root { Ok(Some(())) } else { Ok(None) }
} else {
Ok(None)
}
},
)
.await?;
context
.call_remote::<TunnelContext>(
&parent_method.iter().chain(method.iter()).join("."),
to_value(&TunnelCertData {
key,
cert: Pem(chain),
})?,
)
.await?;
Ok(())
}
#[derive(Debug, Deserialize, Serialize, Parser)]
pub struct GenerateCertParams {
#[arg(help = "Subject Alternative Name(s)")]
pub subject: Vec<InternedString>,
}
pub async fn generate_certificate(
ctx: TunnelContext,
GenerateCertParams { subject }: GenerateCertParams,
) -> Result<Pem<X509>, Error> {
let saninfo = SANInfo::new(&subject.into_iter().collect());
let key = crate::net::ssl::generate_key()?;
let cert = crate::net::ssl::make_self_signed((&key, &saninfo))?;
ctx.db
.mutate(|db| {
db.as_webserver_mut()
.as_certificate_mut()
.ser(&Some(TunnelCertData {
key: Pem(key),
cert: Pem(vec![cert.clone()]),
}))
})
.await
.result?;
Ok(Pem(cert))
}
pub async fn get_certificate(ctx: TunnelContext) -> Result<Option<Pem<Vec<X509>>>, Error> {
ctx.db
.peek()
.await
.as_webserver()
.as_certificate()
.de()?
.map(|cert_data| Ok(cert_data.cert))
.transpose()
}
#[derive(Debug, Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct SetListenParams {
pub listen: SocketAddr,
}
pub async fn set_listen(
ctx: TunnelContext,
SetListenParams { listen }: SetListenParams,
) -> Result<(), Error> {
// Validate that the address is available to bind
tokio::net::TcpListener::bind(listen)
.await
.with_kind(ErrorKind::Network)
.with_ctx(|_| {
(
ErrorKind::Network,
format!("{} is not available to bind to", listen),
)
})?;
ctx.db
.mutate(|db| {
db.as_webserver_mut().as_listen_mut().ser(&Some(listen))?;
Ok(())
})
.await
.result
}
pub async fn get_listen(ctx: TunnelContext) -> Result<Option<SocketAddr>, Error> {
ctx.db.peek().await.as_webserver().as_listen().de()
}
pub async fn get_available_ips(ctx: TunnelContext) -> Result<Vec<IpAddr>, Error> {
let ips = ctx.net_iface.peek(|interfaces| {
interfaces
.values()
.flat_map(|info| {
info.ip_info
.iter()
.flat_map(|ip_info| ip_info.subnets.iter().map(|subnet| subnet.addr()))
})
.collect::<Vec<IpAddr>>()
});
Ok(ips)
}
pub async fn enable_web(ctx: TunnelContext) -> Result<(), Error> {
ctx.db
.mutate(|db| {
if db.as_webserver().as_listen().transpose_ref().is_none() {
return Err(Error::new(
eyre!("Listen is not set"),
ErrorKind::ParseNetAddress,
));
}
if db.as_webserver().as_certificate().transpose_ref().is_none() {
return Err(Error::new(
eyre!("Certificate is not set"),
ErrorKind::OpenSsl,
));
}
if db.as_password().transpose_ref().is_none() {
return Err(Error::new(
eyre!("Password is not set"),
ErrorKind::Authorization,
));
};
db.as_webserver_mut().as_enabled_mut().ser(&true)
})
.await
.result
}
pub async fn disable_web(ctx: TunnelContext) -> Result<(), Error> {
ctx.db
.mutate(|db| db.as_webserver_mut().as_enabled_mut().ser(&false))
.await
.result
}
pub async fn reset_web(ctx: TunnelContext) -> Result<(), Error> {
ctx.db
.mutate(|db| {
db.as_webserver_mut().as_enabled_mut().ser(&false)?;
db.as_webserver_mut().as_listen_mut().ser(&None)?;
db.as_webserver_mut().as_certificate_mut().ser(&None)?;
db.as_password_mut().ser(&None)?;
Ok(())
})
.await
.result
}
fn is_valid_domain(domain: &str) -> bool {
if domain.is_empty() || domain.len() > 253 || domain.starts_with('.') || domain.ends_with('.') {
return false;
}
let labels: Vec<&str> = domain.split('.').collect();
for label in labels {
if label.is_empty() || label.len() > 63 {
return false;
}
if !label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') {
return false;
}
if label.chars().next().map_or(true, |c| c == '-')
|| label.chars().next_back().map_or(true, |c| c == '-')
{
return false;
}
}
true
}
pub async fn init_web(ctx: CliContext) -> Result<(), Error> {
let mut password = None;
loop {
match ctx
.call_remote::<TunnelContext>("web.enable", json!({}))
.await
{
Ok(_) => {
let listen = from_value::<SocketAddr>(
ctx.call_remote::<TunnelContext>("web.get-listen", json!({}))
.await?,
)?;
println!("✅ Success! ✅");
println!(
"The webserver is running. Below is your URL{} and SSL certificate.",
if password.is_some() {
", password,"
} else {
""
}
);
println!();
println!("🌐 URL");
println!("https://{listen}");
if listen.ip().is_unspecified() {
println!(concat!(
"Note: this is the unspecified address. ",
"This means you can use any IP address available to this device to connect. ",
"Using the above address as-is will only work from this device."
));
} else if listen.ip().is_loopback() {
println!(concat!(
"Note: this is a loopback address. ",
"This is only recommended if you are planning to run a proxy in front of the web ui. ",
"Using the above address as-is will only work from this device."
));
}
println!();
if let Some(password) = password {
println!("🔒 Password");
println!("{password}");
println!();
println!(concat!(
"If you lose or forget your password, you can reset it using the command: ",
"start-tunnel auth reset-password"
));
} else {
println!(concat!(
"Your password was set up previously. ",
"If you don't remember it, you can reset it using the command: ",
"start-tunnel auth reset-password"
));
}
println!();
let cert = from_value::<Pem<Vec<X509>>>(
ctx.call_remote::<TunnelContext>("web.get-certificate", json!({}))
.await?,
)?;
println!("📝 SSL Certificate:");
print!("{cert}");
println!(concat!(
"If you haven't already, ",
"trust the certificate in your system keychain and/or browser."
));
return Ok(());
}
Err(e) if e.kind == ErrorKind::ParseNetAddress => {
println!("Select the IP address at which to host the web interface:");
let mut suggested_addrs = from_value::<Vec<IpAddr>>(
ctx.call_remote::<TunnelContext>("web.get-available-ips", json!({}))
.await?,
)?;
suggested_addrs.sort_by_cached_key(|a| match a {
IpAddr::V4(a) => {
if a.is_loopback() {
3
} else if a.is_private() {
2
} else {
0
}
}
IpAddr::V6(a) => {
if a.is_loopback() {
5
} else if a.is_unicast_link_local() {
4
} else {
1
}
}
});
let ip = if suggested_addrs.is_empty() {
prompt("Listen Address: ", parse_as::<IpAddr>("IP Address"), None).await?
} else if suggested_addrs.len() > 16 {
prompt(
&format!("Listen Address [{}]: ", suggested_addrs[0]),
parse_as::<IpAddr>("IP Address"),
Some(suggested_addrs[0]),
)
.await?
} else {
*choose_custom_display("Listen Address:", &suggested_addrs, |a| match a {
a if a.is_loopback() => {
format!("{a} (Loopback Address: only use if planning to proxy traffic)")
}
IpAddr::V4(a) if a.is_private() => {
format!("{a} (Private Address: only available from Local Area Network)")
}
IpAddr::V6(a) if a.is_unicast_link_local() => {
format!(
"[{a}] (Private Address: only available from Local Area Network)"
)
}
IpAddr::V6(a) => format!("[{a}]"),
a => a.to_string(),
})
.await?
};
println!(concat!(
"Enter the port at which to host the web interface. ",
"The recommended default is 8443. ",
"If you change the default, choose an uncommon port to avoid conflicts: "
));
let port = prompt("Port [8443]: ", parse_as::<u16>("port"), Some(8443)).await?;
let listen = SocketAddr::new(ip, port);
ctx.call_remote::<TunnelContext>(
"web.set-listen",
to_value(&SetListenParams { listen })?,
)
.await?;
println!();
}
Err(e) if e.kind == ErrorKind::OpenSsl => {
enum Choice {
Generate,
Provide,
}
impl std::fmt::Display for Choice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Generate => write!(f, "Generate a Self Signed Certificate"),
Self::Provide => write!(f, "Provide your own certificate and key"),
}
}
}
let options = vec![Choice::Generate, Choice::Provide];
let choice = choose(
concat!(
"Select whether to autogenerate a self-signed SSL certificate ",
"or provide your own certificate and key:"
),
&options,
)
.await?;
match choice {
Choice::Generate => {
let listen = from_value::<Option<SocketAddr>>(
ctx.call_remote::<TunnelContext>("web.get-listen", json!({}))
.await?,
)?
.filter(|a| !a.ip().is_unspecified());
let default_prompt = if let Some(listen) = listen {
format!("Subject Alternative Name(s) [{}]: ", listen.ip())
} else {
"Subject Alternative Name(s): ".to_string()
};
println!(
"List all IP addresses and domains for which to sign the certificate, separated by commas."
);
let san_info = prompt(
&default_prompt,
|s| {
s.split(",")
.map(|s| {
let s = s.trim();
if let Ok(ip) = s.parse::<IpAddr>() {
Ok(InternedString::from_display(&ip))
} else if is_valid_domain(s) {
Ok(s.into())
} else {
Err(format!("{s} is not a valid ip address or domain"))
}
})
.collect()
},
listen.map(|l| vec![InternedString::from_display(&l.ip())]),
)
.await?;
ctx.call_remote::<TunnelContext>(
"web.generate-certificate",
to_value(&GenerateCertParams { subject: san_info })?,
)
.await?;
}
Choice::Provide => {
import_certificate_cli(HandlerArgs {
context: ctx.clone(),
parent_method: vec!["web", "import-certificate"].into(),
method: VecDeque::new(),
params: Empty {},
inherited_params: Empty {},
raw_params: json!({}),
})
.await?;
}
}
println!();
}
Err(e) if e.kind == ErrorKind::Authorization => {
println!("Generating a random password...");
let params = SetPasswordParams {
password: base32::encode(
base32::Alphabet::Rfc4648Lower { padding: false },
&rand::random::<[u8; 16]>(),
),
};
ctx.call_remote::<TunnelContext>("auth.set-password", to_value(&params)?)
.await?;
password = Some(params.password);
println!();
}
Err(e) => return Err(e.into()),
}
}
}

View File

@@ -1,19 +1,22 @@
use std::collections::BTreeMap;
use std::net::{Ipv4Addr, SocketAddrV4};
use std::net::{Ipv4Addr, SocketAddr};
use ed25519_dalek::{SigningKey, VerifyingKey};
use imbl_value::InternedString;
use ipnet::Ipv4Net;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use ts_rs::TS;
use x25519_dalek::{PublicKey, StaticSecret};
use crate::prelude::*;
use crate::util::Invoke;
use crate::util::io::write_file_atomic;
use crate::util::serde::Base64;
#[derive(Deserialize, Serialize, HasModel)]
pub const WIREGUARD_INTERFACE_NAME: &str = "wg-start-tunnel";
#[derive(Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
pub struct WgServer {
@@ -37,7 +40,7 @@ impl WgServer {
pub async fn sync(&self) -> Result<(), Error> {
Command::new("wg-quick")
.arg("down")
.arg("wg0")
.arg(WIREGUARD_INTERFACE_NAME)
.invoke(ErrorKind::Network)
.await
.or_else(|e| {
@@ -49,21 +52,23 @@ impl WgServer {
}
})?;
write_file_atomic(
"/etc/wireguard/wg0.conf",
const_format::formatcp!("/etc/wireguard/{WIREGUARD_INTERFACE_NAME}.conf"),
self.server_config().to_string().as_bytes(),
)
.await?;
Command::new("wg-quick")
.arg("up")
.arg("wg0")
.arg(WIREGUARD_INTERFACE_NAME)
.invoke(ErrorKind::Network)
.await?;
Ok(())
}
}
#[derive(Default, Deserialize, Serialize)]
pub struct WgSubnetMap(pub BTreeMap<Ipv4Net, WgSubnetConfig>);
#[derive(Default, Deserialize, Serialize, TS)]
pub struct WgSubnetMap(
#[ts(as = "BTreeMap::<String, WgSubnetConfig>")] pub BTreeMap<Ipv4Net, WgSubnetConfig>,
);
impl Map for WgSubnetMap {
type Key = Ipv4Net;
type Value = WgSubnetConfig;
@@ -75,35 +80,41 @@ impl Map for WgSubnetMap {
}
}
#[derive(Default, Deserialize, Serialize, HasModel)]
#[derive(Default, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
pub struct WgSubnetConfig {
pub default_forward_target: Option<Ipv4Addr>,
pub clients: BTreeMap<Ipv4Addr, WgConfig>,
pub name: InternedString,
pub clients: WgSubnetClients,
}
impl WgSubnetConfig {
pub fn new() -> Self {
Self::default()
}
pub fn add_client<'a>(
&'a mut self,
subnet: Ipv4Net,
) -> Result<(Ipv4Addr, &'a WgConfig), Error> {
let addr = subnet
.hosts()
.find(|a| !self.clients.contains_key(a))
.ok_or_else(|| Error::new(eyre!("subnet exhausted"), ErrorKind::Network))?;
let config = self.clients.entry(addr).or_insert(WgConfig::generate());
Ok((addr, config))
pub fn new(name: InternedString) -> Self {
Self {
name,
..Self::default()
}
}
}
pub struct WgKey(SigningKey);
#[derive(Default, Deserialize, Serialize, TS)]
pub struct WgSubnetClients(pub BTreeMap<Ipv4Addr, WgConfig>);
impl Map for WgSubnetClients {
type Key = Ipv4Addr;
type Value = WgConfig;
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
Self::key_string(key)
}
fn key_string(key: &Self::Key) -> Result<InternedString, Error> {
Ok(InternedString::from_display(key))
}
}
#[derive(Clone)]
pub struct WgKey(StaticSecret);
impl WgKey {
pub fn generate() -> Self {
Self(SigningKey::generate(
&mut ssh_key::rand_core::OsRng::default(),
Self(StaticSecret::random_from_rng(
ssh_key::rand_core::OsRng::default(),
))
}
}
@@ -113,33 +124,39 @@ impl AsRef<[u8]> for WgKey {
}
}
impl TryFrom<Vec<u8>> for WgKey {
type Error = ed25519_dalek::SignatureError;
type Error = Error;
fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
Ok(Self(value.as_slice().try_into()?))
Ok(Self(
<[u8; 32]>::try_from(value)
.map_err(|_| Error::new(eyre!("invalid key length"), ErrorKind::Deserialization))?
.into(),
))
}
}
impl std::ops::Deref for WgKey {
type Target = SigningKey;
type Target = StaticSecret;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Base64<WgKey> {
pub fn verifying_key(&self) -> Base64<VerifyingKey> {
Base64(self.0.verifying_key())
pub fn verifying_key(&self) -> Base64<PublicKey> {
Base64((&*self.0).into())
}
}
#[derive(Deserialize, Serialize, HasModel)]
#[derive(Clone, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
pub struct WgConfig {
pub name: InternedString,
pub key: Base64<WgKey>,
pub psk: Base64<[u8; 32]>,
}
impl WgConfig {
pub fn generate() -> Self {
pub fn generate(name: InternedString) -> Self {
Self {
name,
key: Base64(WgKey::generate()),
psk: Base64(rand::random()),
}
@@ -150,12 +167,12 @@ impl WgConfig {
client_addr: addr,
}
}
pub fn client_config<'a>(
&'a self,
pub fn client_config(
self,
addr: Ipv4Addr,
server_pubkey: Base64<VerifyingKey>,
server_addr: SocketAddrV4,
) -> ClientConfig<'a> {
server_pubkey: Base64<PublicKey>,
server_addr: SocketAddr,
) -> ClientConfig {
ClientConfig {
client_config: self,
client_addr: addr,
@@ -181,19 +198,33 @@ impl<'a> std::fmt::Display for ServerPeerConfig<'a> {
}
}
pub struct ClientConfig<'a> {
client_config: &'a WgConfig,
client_addr: Ipv4Addr,
server_pubkey: Base64<VerifyingKey>,
server_addr: SocketAddrV4,
fn deserialize_verifying_key<'de, D>(deserializer: D) -> Result<Base64<PublicKey>, D::Error>
where
D: serde::Deserializer<'de>,
{
Base64::<Vec<u8>>::deserialize(deserializer).and_then(|b| {
Ok(Base64(PublicKey::from(<[u8; 32]>::try_from(b.0).map_err(
|e: Vec<u8>| serde::de::Error::invalid_length(e.len(), &"a 32 byte base64 string"),
)?)))
})
}
impl<'a> std::fmt::Display for ClientConfig<'a> {
#[derive(Clone, Serialize, Deserialize)]
pub struct ClientConfig {
client_config: WgConfig,
client_addr: Ipv4Addr,
#[serde(deserialize_with = "deserialize_verifying_key")]
server_pubkey: Base64<PublicKey>,
server_addr: SocketAddr,
}
impl std::fmt::Display for ClientConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
include_str!("./client.conf.template"),
name = self.client_config.name,
privkey = self.client_config.key.to_padded_string(),
psk = self.client_config.psk,
psk = self.client_config.psk.to_padded_string(),
addr = self.client_addr,
server_pubkey = self.server_pubkey.to_padded_string(),
server_addr = self.server_addr,
@@ -212,7 +243,7 @@ impl<'a> std::fmt::Display for ServerConfig<'a> {
server_port = server.port,
server_privkey = server.key.to_padded_string(),
)?;
for (addr, peer) in server.subnets.0.values().flat_map(|s| &s.clients) {
for (addr, peer) in server.subnets.0.values().flat_map(|s| &s.clients.0) {
write!(f, "{}", peer.server_peer_config(*addr))?;
}
Ok(())