Refactor/project structure (#3085)

* refactor project structure

* environment-based default registry

* fix tests

* update build container

* use docker platform for iso build emulation

* simplify compat

* Fix docker platform spec in run-compat.sh

* handle riscv compat

* fix bug with dep error exists attr

* undo removal of sorting

* use qemu for iso stage

---------

Co-authored-by: Mariusz Kogen <k0gen@pm.me>
Co-authored-by: Matt Hill <mattnine@protonmail.com>
This commit is contained in:
Aiden McClelland
2025-12-22 13:39:38 -07:00
committed by GitHub
parent eda08d5b0f
commit 96ae532879
389 changed files with 744 additions and 4005 deletions

508
core/src/tunnel/api.rs Normal file
View File

@@ -0,0 +1,508 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use clap::Parser;
use imbl_value::InternedString;
use ipnet::Ipv4Net;
use rpc_toolkit::{Context, Empty, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use crate::context::CliContext;
use crate::db::model::public::NetworkInterfaceType;
use crate::net::forward::add_iptables_rule;
use crate::prelude::*;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::wg::{WIREGUARD_INTERFACE_NAME, WgConfig, WgSubnetClients, WgSubnetConfig};
use crate::util::serde::{HandlerExtSerde, display_serializable};
pub fn tunnel_api<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
.subcommand("web", super::web::web_api::<C>())
.subcommand(
"db",
super::db::db_api::<C>()
.with_about("Commands to interact with the db i.e. dump and apply"),
)
.subcommand(
"auth",
super::auth::auth_api::<C>().with_about("Add or remove authorized clients"),
)
.subcommand(
"subnet",
subnet_api::<C>().with_about("Add, remove, or modify subnets"),
)
.subcommand(
"device",
device_api::<C>().with_about("Add, remove, or list devices in subnets"),
)
.subcommand(
"port-forward",
ParentHandler::<C>::new()
.subcommand(
"add",
from_fn_async(add_forward)
.with_metadata("sync_db", Value::Bool(true))
.no_display()
.with_about("Add a new port forward")
.with_call_remote::<CliContext>(),
)
.subcommand(
"remove",
from_fn_async(remove_forward)
.with_metadata("sync_db", Value::Bool(true))
.no_display()
.with_about("Remove a port forward")
.with_call_remote::<CliContext>(),
),
)
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct SubnetParams {
subnet: Ipv4Net,
}
pub fn subnet_api<C: Context>() -> ParentHandler<C, SubnetParams> {
ParentHandler::new()
.subcommand(
"add",
from_fn_async(add_subnet)
.with_metadata("sync_db", Value::Bool(true))
.with_inherited(|a, _| a)
.no_display()
.with_about("Add a new subnet")
.with_call_remote::<CliContext>(),
)
.subcommand(
"remove",
from_fn_async(remove_subnet)
.with_metadata("sync_db", Value::Bool(true))
.with_inherited(|a, _| a)
.no_display()
.with_about("Remove a subnet")
.with_call_remote::<CliContext>(),
)
}
pub fn device_api<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
.subcommand(
"add",
from_fn_async(add_device)
.with_metadata("sync_db", Value::Bool(true))
.no_display()
.with_about("Add a device to a subnet")
.with_call_remote::<CliContext>(),
)
.subcommand(
"remove",
from_fn_async(remove_device)
.with_metadata("sync_db", Value::Bool(true))
.no_display()
.with_about("Remove a device from a subnet")
.with_call_remote::<CliContext>(),
)
.subcommand(
"list",
from_fn_async(list_devices)
.with_display_serializable()
.with_custom_display_fn(|HandlerArgs { params, .. }, res| {
use prettytable::*;
if let Some(format) = params.format {
return display_serializable(format, res);
}
let mut table = Table::new();
table.add_row(row![bc => "NAME", "IP", "PUBLIC KEY"]);
for (ip, config) in res.clients.0 {
table.add_row(row![config.name, ip, config.key.verifying_key()]);
}
table.print_tty(false)?;
Ok(())
})
.with_about("List devices in a subnet")
.with_call_remote::<CliContext>(),
)
.subcommand(
"show-config",
from_fn_async(show_config)
.with_about("Show the WireGuard configuration for a device")
.with_call_remote::<CliContext>(),
)
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct AddSubnetParams {
name: InternedString,
}
pub async fn add_subnet(
ctx: TunnelContext,
AddSubnetParams { name }: AddSubnetParams,
SubnetParams { mut subnet }: SubnetParams,
) -> Result<(), Error> {
if subnet.prefix_len() > 24 {
return Err(Error::new(
eyre!("invalid subnet"),
ErrorKind::InvalidRequest,
));
}
let addr = subnet
.hosts()
.next()
.ok_or_else(|| Error::new(eyre!("invalid subnet"), ErrorKind::InvalidRequest))?;
subnet = Ipv4Net::new_assert(addr, subnet.prefix_len());
let server = ctx
.db
.mutate(|db| {
let map = db.as_wg_mut().as_subnets_mut();
if let Some(s) = map
.keys()?
.into_iter()
.find(|s| s != &subnet && (s.contains(&subnet) || subnet.contains(s)))
{
return Err(Error::new(
eyre!("{subnet} overlaps with existing subnet {s}"),
ErrorKind::InvalidRequest,
));
}
map.upsert(&subnet, || {
Ok(WgSubnetConfig::new(InternedString::default()))
})?
.as_name_mut()
.ser(&name)?;
db.as_wg().de()
})
.await
.result?;
server.sync().await?;
for iface in ctx.net_iface.peek(|i| {
i.iter()
.filter(|(_, info)| {
info.ip_info.as_ref().map_or(false, |i| {
i.device_type != Some(NetworkInterfaceType::Loopback)
})
})
.map(|(name, _)| name)
.filter(|id| id.as_str() != WIREGUARD_INTERFACE_NAME)
.cloned()
.collect::<Vec<_>>()
}) {
add_iptables_rule(
true,
false,
&[
"POSTROUTING",
"-s",
&subnet.trunc().to_string(),
"-o",
iface.as_str(),
"-j",
"MASQUERADE",
],
)
.await?;
}
Ok(())
}
pub async fn remove_subnet(
ctx: TunnelContext,
_: Empty,
SubnetParams { subnet }: SubnetParams,
) -> Result<(), Error> {
let (server, keep) = ctx
.db
.mutate(|db| {
db.as_wg_mut().as_subnets_mut().remove(&subnet)?;
Ok((db.as_wg().de()?, db.gc_forwards()?))
})
.await
.result?;
server.sync().await?;
ctx.gc_forwards(&keep).await?;
for iface in ctx.net_iface.peek(|i| {
i.iter()
.filter(|(_, info)| {
info.ip_info.as_ref().map_or(false, |i| {
i.device_type != Some(NetworkInterfaceType::Loopback)
})
})
.map(|(name, _)| name)
.filter(|id| id.as_str() != WIREGUARD_INTERFACE_NAME)
.cloned()
.collect::<Vec<_>>()
}) {
add_iptables_rule(
true,
true,
&[
"POSTROUTING",
"-s",
&subnet.trunc().to_string(),
"-o",
iface.as_str(),
"-j",
"MASQUERADE",
],
)
.await?;
}
Ok(())
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct AddDeviceParams {
subnet: Ipv4Net,
name: InternedString,
ip: Option<Ipv4Addr>,
}
pub async fn add_device(
ctx: TunnelContext,
AddDeviceParams { subnet, name, ip }: AddDeviceParams,
) -> Result<(), Error> {
let server = ctx
.db
.mutate(|db| {
db.as_wg_mut()
.as_subnets_mut()
.as_idx_mut(&subnet)
.or_not_found(&subnet)?
.as_clients_mut()
.mutate(|WgSubnetClients(clients)| {
let ip = if let Some(ip) = ip {
ip
} else {
subnet
.hosts()
.find(|ip| !clients.contains_key(ip) && *ip != subnet.addr())
.ok_or_else(|| {
Error::new(
eyre!("no available ips in subnet"),
ErrorKind::InvalidRequest,
)
})?
};
if ip.octets()[3] == 0 || ip.octets()[3] == 255 {
return Err(Error::new(eyre!("invalid ip"), ErrorKind::InvalidRequest));
}
if ip == subnet.addr() {
return Err(Error::new(eyre!("invalid ip"), ErrorKind::InvalidRequest));
}
if !subnet.contains(&ip) {
return Err(Error::new(
eyre!("ip not in subnet"),
ErrorKind::InvalidRequest,
));
}
let client = clients
.entry(ip)
.or_insert_with(|| WgConfig::generate(name.clone()));
client.name = name;
Ok(())
})?;
db.as_wg().de()
})
.await
.result?;
server.sync().await
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct RemoveDeviceParams {
subnet: Ipv4Net,
ip: Ipv4Addr,
}
pub async fn remove_device(
ctx: TunnelContext,
RemoveDeviceParams { subnet, ip }: RemoveDeviceParams,
) -> Result<(), Error> {
let (server, keep) = ctx
.db
.mutate(|db| {
db.as_wg_mut()
.as_subnets_mut()
.as_idx_mut(&subnet)
.or_not_found(&subnet)?
.as_clients_mut()
.remove(&ip)?
.or_not_found(&ip)?;
Ok((db.as_wg().de()?, db.gc_forwards()?))
})
.await
.result?;
server.sync().await?;
ctx.gc_forwards(&keep).await
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct ListDevicesParams {
subnet: Ipv4Net,
}
pub async fn list_devices(
ctx: TunnelContext,
ListDevicesParams { subnet }: ListDevicesParams,
) -> Result<WgSubnetConfig, Error> {
ctx.db
.peek()
.await
.as_wg()
.as_subnets()
.as_idx(&subnet)
.or_not_found(&subnet)?
.de()
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct ShowConfigParams {
subnet: Ipv4Net,
ip: Ipv4Addr,
wan_addr: Option<IpAddr>,
#[serde(rename = "__ConnectInfo_local_addr")]
#[arg(skip)]
local_addr: Option<SocketAddr>,
}
pub async fn show_config(
ctx: TunnelContext,
ShowConfigParams {
subnet,
ip,
wan_addr,
local_addr,
}: ShowConfigParams,
) -> Result<String, Error> {
let peek = ctx.db.peek().await;
let wg = peek.as_wg();
let client = wg
.as_subnets()
.as_idx(&subnet)
.or_not_found(&subnet)?
.as_clients()
.as_idx(&ip)
.or_not_found(&ip)?
.de()?;
let wan_addr = if let Some(wan_addr) = wan_addr.or(local_addr.map(|a| a.ip())).filter(|ip| {
!ip.is_loopback()
&& !match ip {
IpAddr::V4(ipv4) => ipv4.is_private() || ipv4.is_link_local(),
IpAddr::V6(ipv6) => ipv6.is_unique_local() || ipv6.is_unicast_link_local(),
}
}) {
wan_addr
} else if let Some(webserver) = peek.as_webserver().as_listen().de()? {
webserver.ip()
} else {
ctx.net_iface
.peek(|i| {
i.iter().find_map(|(_, info)| {
info.ip_info
.as_ref()
.filter(|_| info.public())
.iter()
.find_map(|info| info.subnets.iter().next())
.copied()
})
})
.or_not_found("a public IP address")?
.addr()
};
Ok(client
.client_config(
ip,
subnet,
wg.as_key().de()?.verifying_key(),
(wan_addr, wg.as_port().de()?).into(),
)
.to_string())
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct AddPortForwardParams {
source: SocketAddrV4,
target: SocketAddrV4,
}
pub async fn add_forward(
ctx: TunnelContext,
AddPortForwardParams { source, target }: AddPortForwardParams,
) -> Result<(), Error> {
let prefix = ctx
.net_iface
.peek(|i| {
i.iter()
.find_map(|(_, i)| {
i.ip_info.as_ref().and_then(|i| {
i.subnets
.iter()
.find(|s| s.contains(&IpAddr::from(*target.ip())))
})
})
.cloned()
})
.map(|s| s.prefix_len())
.unwrap_or(32);
let rc = ctx.forward.add_forward(source, target, prefix).await?;
ctx.active_forwards.mutate(|m| {
m.insert(source, rc);
});
ctx.db
.mutate(|db| {
db.as_port_forwards_mut()
.insert(&source, &target)
.and_then(|replaced| {
if replaced.is_some() {
Err(Error::new(
eyre!("Port forward from {source} already exists"),
ErrorKind::InvalidRequest,
))
} else {
Ok(())
}
})
})
.await
.result?;
Ok(())
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct RemovePortForwardParams {
source: SocketAddrV4,
}
pub async fn remove_forward(
ctx: TunnelContext,
RemovePortForwardParams { source, .. }: RemovePortForwardParams,
) -> Result<(), Error> {
ctx.db
.mutate(|db| db.as_port_forwards_mut().remove(&source))
.await
.result?;
if let Some(rc) = ctx.active_forwards.mutate(|m| m.remove(&source)) {
drop(rc);
ctx.forward.gc().await?;
}
Ok(())
}

299
core/src/tunnel/auth.rs Normal file
View File

@@ -0,0 +1,299 @@
use clap::Parser;
use imbl::HashMap;
use imbl_value::InternedString;
use itertools::Itertools;
use patch_db::HasModel;
use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use crate::auth::{Sessions, check_password};
use crate::context::CliContext;
use crate::middleware::auth::DbContext;
use crate::middleware::auth::local::LocalAuthContext;
use crate::middleware::auth::session::SessionAuthContext;
use crate::middleware::auth::signature::SignatureAuthContext;
use crate::prelude::*;
use crate::rpc_continuations::OpenAuthedContinuations;
use crate::sign::AnyVerifyingKey;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::db::TunnelDatabase;
use crate::util::serde::{HandlerExtSerde, display_serializable};
use crate::util::sync::SyncMutex;
impl DbContext for TunnelContext {
type Database = TunnelDatabase;
fn db(&self) -> &TypedPatchDb<Self::Database> {
&self.db
}
}
impl SignatureAuthContext for TunnelContext {
type AdditionalMetadata = ();
type CheckPubkeyRes = ();
async fn sig_context(
&self,
) -> impl IntoIterator<Item = Result<impl AsRef<str> + Send, Error>> + Send {
let peek = self.db().peek().await;
peek.as_webserver()
.as_listen()
.de()
.map(|a| a.as_ref().map(InternedString::from_display))
.transpose()
.into_iter()
.chain(
std::iter::once_with(move || {
peek.as_webserver()
.as_certificate()
.de()
.ok()
.flatten()
.and_then(|cert_data| cert_data.cert.0.first().cloned())
.and_then(|cert| cert.subject_alt_names())
.into_iter()
.flatten()
.filter_map(|san| {
san.dnsname().map(InternedString::from).or_else(|| {
san.ipaddress().and_then(|ip_bytes| {
let ip: std::net::IpAddr = match ip_bytes.len() {
4 => std::net::IpAddr::V4(std::net::Ipv4Addr::from(
<[u8; 4]>::try_from(ip_bytes).ok()?,
)),
16 => std::net::IpAddr::V6(std::net::Ipv6Addr::from(
<[u8; 16]>::try_from(ip_bytes).ok()?,
)),
_ => return None,
};
Some(InternedString::from_display(&ip))
})
})
})
.map(Ok)
.collect::<Vec<_>>()
})
.flatten(),
)
}
fn check_pubkey(
db: &Model<Self::Database>,
pubkey: Option<&crate::sign::AnyVerifyingKey>,
_: Self::AdditionalMetadata,
) -> Result<Self::CheckPubkeyRes, Error> {
if let Some(pubkey) = pubkey {
if db.as_auth_pubkeys().de()?.contains_key(pubkey) {
return Ok(());
}
}
Err(Error::new(
eyre!("Key is not authorized"),
ErrorKind::IncorrectPassword,
))
}
async fn post_auth_hook(
&self,
_: Self::CheckPubkeyRes,
_: &rpc_toolkit::RpcRequest,
) -> Result<(), Error> {
Ok(())
}
}
impl LocalAuthContext for TunnelContext {
const LOCAL_AUTH_COOKIE_PATH: &str = "/run/startos/tunnel.authcookie";
const LOCAL_AUTH_COOKIE_OWNERSHIP: &str = "root:root";
}
impl SessionAuthContext for TunnelContext {
fn access_sessions(db: &mut Model<Self::Database>) -> &mut Model<crate::auth::Sessions> {
db.as_sessions_mut()
}
fn ephemeral_sessions(&self) -> &SyncMutex<Sessions> {
&self.ephemeral_sessions
}
fn open_authed_continuations(&self) -> &OpenAuthedContinuations<Option<InternedString>> {
&self.open_authed_continuations
}
fn check_password(db: &Model<Self::Database>, password: &str) -> Result<(), Error> {
check_password(&db.as_password().de()?.unwrap_or_default(), password)
}
}
#[derive(Clone, Debug, Deserialize, Serialize, HasModel, TS, Parser)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
pub struct SignerInfo {
pub name: InternedString,
}
pub fn auth_api<C: Context>() -> ParentHandler<C> {
crate::auth::auth::<C, TunnelContext>()
.subcommand("set-password", from_fn_async(set_password_rpc).no_cli())
.subcommand(
"set-password",
from_fn_async(set_password_cli)
.with_about("Set user interface password")
.no_display(),
)
.subcommand(
"reset-password",
from_fn_async(reset_password)
.with_about("Reset user interface password")
.no_display(),
)
.subcommand(
"key",
ParentHandler::<C>::new()
.subcommand(
"add",
from_fn_async(add_key)
.with_metadata("sync_db", Value::Bool(true))
.no_display()
.with_about("Add a new authorized key")
.with_call_remote::<CliContext>(),
)
.subcommand(
"remove",
from_fn_async(remove_key)
.with_metadata("sync_db", Value::Bool(true))
.no_display()
.with_about("Remove an authorized key")
.with_call_remote::<CliContext>(),
)
.subcommand(
"list",
from_fn_async(list_keys)
.with_metadata("sync_db", Value::Bool(true))
.with_display_serializable()
.with_custom_display_fn(|HandlerArgs { params, .. }, res| {
use prettytable::*;
if let Some(format) = params.format {
return display_serializable(format, res);
}
let mut table = Table::new();
table.add_row(row![bc => "NAME", "KEY"]);
for (key, info) in res {
table.add_row(row![info.name, key]);
}
table.print_tty(false)?;
Ok(())
})
.with_about("List authorized keys")
.with_call_remote::<CliContext>(),
),
)
}
#[derive(Debug, Deserialize, Serialize, Parser, TS)]
#[serde(rename_all = "camelCase")]
pub struct AddKeyParams {
pub name: InternedString,
pub key: AnyVerifyingKey,
}
pub async fn add_key(
ctx: TunnelContext,
AddKeyParams { name, key }: AddKeyParams,
) -> Result<(), Error> {
ctx.db
.mutate(|db| {
db.as_auth_pubkeys_mut().mutate(|auth_pubkeys| {
auth_pubkeys.insert(key, SignerInfo { name });
Ok(())
})
})
.await
.result
}
#[derive(Debug, Deserialize, Serialize, Parser, TS)]
#[serde(rename_all = "camelCase")]
pub struct RemoveKeyParams {
pub key: AnyVerifyingKey,
}
pub async fn remove_key(
ctx: TunnelContext,
RemoveKeyParams { key }: RemoveKeyParams,
) -> Result<(), Error> {
ctx.db
.mutate(|db| {
db.as_auth_pubkeys_mut()
.mutate(|auth_pubkeys| Ok(auth_pubkeys.remove(&key)))
})
.await
.result?;
Ok(())
}
pub async fn list_keys(ctx: TunnelContext) -> Result<HashMap<AnyVerifyingKey, SignerInfo>, Error> {
ctx.db.peek().await.into_auth_pubkeys().de()
}
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
pub struct SetPasswordParams {
pub password: String,
}
pub async fn set_password_rpc(
ctx: TunnelContext,
SetPasswordParams { password }: SetPasswordParams,
) -> Result<(), Error> {
let pwhash = argon2::hash_encoded(
password.as_bytes(),
&rand::random::<[u8; 16]>(),
&argon2::Config::rfc9106_low_mem(),
)
.with_kind(ErrorKind::PasswordHashGeneration)?;
ctx.db
.mutate(|db| db.as_password_mut().ser(&Some(pwhash)))
.await
.result?;
Ok(())
}
pub async fn set_password_cli(
HandlerArgs {
context,
parent_method,
method,
..
}: HandlerArgs<CliContext>,
) -> Result<(), Error> {
let password = rpassword::prompt_password("New Password: ")?;
let confirm = rpassword::prompt_password("Confirm Password: ")?;
if password != confirm {
return Err(Error::new(
eyre!("Passwords do not match"),
ErrorKind::InvalidRequest,
));
}
context
.call_remote::<TunnelContext>(
&parent_method.iter().chain(method.iter()).join("."),
to_value(&SetPasswordParams { password })?,
)
.await?;
println!("Password set successfully");
Ok(())
}
pub async fn reset_password(ctx: CliContext) -> Result<(), Error> {
println!("Generating a random password...");
let params = SetPasswordParams {
password: base32::encode(
base32::Alphabet::Rfc4648Lower { padding: false },
&rand::random::<[u8; 16]>(),
),
};
ctx.call_remote::<TunnelContext>("auth.set-password", to_value(&params)?)
.await?;
println!("Your new password is:");
println!("{}", params.password);
Ok(())
}

View File

@@ -0,0 +1,12 @@
# StartTunnel config for {name}
[Interface]
Address = {addr}
PrivateKey = {privkey}
[Peer]
PublicKey = {server_pubkey}
PresharedKey = {psk}
AllowedIPs = {subnet}
Endpoint = {server_addr}
PersistentKeepalive = 25

356
core/src/tunnel/context.rs Normal file
View File

@@ -0,0 +1,356 @@
use std::collections::{BTreeMap, BTreeSet};
use std::net::{IpAddr, SocketAddr, SocketAddrV4};
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::{Arc, OnceLock};
use clap::Parser;
use cookie::{Cookie, Expiration, SameSite};
use http::HeaderMap;
use imbl::OrdMap;
use imbl_value::InternedString;
use include_dir::Dir;
use ipnet::Ipv4Net;
use patch_db::PatchDb;
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{CallRemote, Context, Empty, ParentHandler};
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast::Sender;
use tracing::instrument;
use url::Url;
use crate::GatewayId;
use crate::auth::Sessions;
use crate::context::config::ContextConfig;
use crate::context::{CliContext, RpcContext};
use crate::db::model::public::{NetworkInterfaceInfo, NetworkInterfaceType};
use crate::middleware::auth::Auth;
use crate::middleware::auth::local::LocalAuthContext;
use crate::middleware::cors::Cors;
use crate::net::forward::{PortForwardController, add_iptables_rule};
use crate::net::static_server::{EMPTY_DIR, UiContext};
use crate::prelude::*;
use crate::rpc_continuations::{OpenAuthedContinuations, RpcContinuations};
use crate::tunnel::TUNNEL_DEFAULT_LISTEN;
use crate::tunnel::api::tunnel_api;
use crate::tunnel::db::TunnelDatabase;
use crate::tunnel::wg::{WIREGUARD_INTERFACE_NAME, WgSubnetConfig};
use crate::util::collections::OrdMapIterMut;
use crate::util::io::read_file_to_string;
use crate::util::sync::{SyncMutex, Watch};
#[derive(Debug, Clone, Default, Deserialize, Serialize, Parser)]
#[serde(rename_all = "kebab-case")]
#[command(rename_all = "kebab-case")]
pub struct TunnelConfig {
#[arg(short = 'c', long = "config")]
pub config: Option<PathBuf>,
#[arg(short = 'l', long = "listen")]
pub tunnel_listen: Option<SocketAddr>,
#[arg(short = 'd', long = "datadir")]
pub datadir: Option<PathBuf>,
}
impl ContextConfig for TunnelConfig {
fn next(&mut self) -> Option<PathBuf> {
self.config.take()
}
fn merge_with(&mut self, other: Self) {
self.tunnel_listen = self.tunnel_listen.take().or(other.tunnel_listen);
self.datadir = self.datadir.take().or(other.datadir);
}
}
impl TunnelConfig {
pub fn load(mut self) -> Result<Self, Error> {
let path = self.next();
self.load_path_rec(path)?;
self.load_path_rec(Some("/etc/start-tunneld"))?;
Ok(self)
}
}
pub struct TunnelContextSeed {
pub listen: SocketAddr,
pub db: TypedPatchDb<TunnelDatabase>,
pub datadir: PathBuf,
pub rpc_continuations: RpcContinuations,
pub open_authed_continuations: OpenAuthedContinuations<Option<InternedString>>,
pub ephemeral_sessions: SyncMutex<Sessions>,
pub net_iface: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
pub forward: PortForwardController,
pub active_forwards: SyncMutex<BTreeMap<SocketAddrV4, Arc<()>>>,
pub shutdown: Sender<()>,
}
#[derive(Clone)]
pub struct TunnelContext(Arc<TunnelContextSeed>);
impl TunnelContext {
#[instrument(skip_all)]
pub async fn init(config: &TunnelConfig) -> Result<Self, Error> {
Self::init_auth_cookie().await?;
let (shutdown, _) = tokio::sync::broadcast::channel(1);
let datadir = config
.datadir
.as_deref()
.unwrap_or_else(|| Path::new("/var/lib/start-tunnel"))
.to_owned();
if tokio::fs::metadata(&datadir).await.is_err() {
tokio::fs::create_dir_all(&datadir).await?;
}
let db_path = datadir.join("tunnel.db");
let db = TypedPatchDb::<TunnelDatabase>::load_or_init(
PatchDb::open(&db_path).await?,
|| async {
let mut db = TunnelDatabase::default();
db.wg.subnets.0.insert(
Ipv4Net::new_assert([10, 59, rand::random(), 1].into(), 24),
WgSubnetConfig {
name: "Default Subnet".into(),
..Default::default()
},
);
Ok(db)
},
)
.await?;
let listen = config.tunnel_listen.unwrap_or(TUNNEL_DEFAULT_LISTEN);
let ip_info = crate::net::utils::load_ip_info().await?;
let net_iface = db
.mutate(|db| {
db.as_gateways_mut().mutate(|g| {
for (_, v) in OrdMapIterMut::from(&mut *g) {
v.ip_info = None;
}
for (id, info) in ip_info {
if id.as_str() != WIREGUARD_INTERFACE_NAME {
g.entry(id).or_default().ip_info = Some(Arc::new(info));
}
}
Ok(g.clone())
})
})
.await
.result?;
let net_iface = Watch::new(net_iface);
let forward = PortForwardController::new();
add_iptables_rule(
false,
false,
&[
"FORWARD",
"-i",
WIREGUARD_INTERFACE_NAME,
"-m",
"state",
"--state",
"NEW",
"-j",
"ACCEPT",
],
)
.await?;
let peek = db.peek().await;
peek.as_wg().de()?.sync().await?;
for iface in net_iface.peek(|i| {
i.iter()
.filter(|(_, info)| {
info.ip_info.as_ref().map_or(false, |i| {
i.device_type != Some(NetworkInterfaceType::Loopback)
})
})
.map(|(name, _)| name)
.filter(|id| id.as_str() != WIREGUARD_INTERFACE_NAME)
.cloned()
.collect::<Vec<_>>()
}) {
for subnet in peek.as_wg().as_subnets().keys()? {
add_iptables_rule(
true,
false,
&[
"POSTROUTING",
"-s",
&subnet.trunc().to_string(),
"-o",
iface.as_str(),
"-j",
"MASQUERADE",
],
)
.await?;
}
}
let mut active_forwards = BTreeMap::new();
for (from, to) in peek.as_port_forwards().de()?.0 {
let prefix = net_iface
.peek(|i| {
i.iter()
.find_map(|(_, i)| {
i.ip_info.as_ref().and_then(|i| {
i.subnets
.iter()
.find(|s| s.contains(&IpAddr::from(*to.ip())))
})
})
.cloned()
})
.map(|s| s.prefix_len())
.unwrap_or(32);
active_forwards.insert(from, forward.add_forward(from, to, prefix).await?);
}
Ok(Self(Arc::new(TunnelContextSeed {
listen,
db,
datadir,
rpc_continuations: RpcContinuations::new(),
open_authed_continuations: OpenAuthedContinuations::new(),
ephemeral_sessions: SyncMutex::new(Sessions::new()),
net_iface,
forward,
active_forwards: SyncMutex::new(active_forwards),
shutdown,
})))
}
pub async fn gc_forwards(&self, keep: &BTreeSet<SocketAddrV4>) -> Result<(), Error> {
self.active_forwards
.mutate(|pf| pf.retain(|k, _| keep.contains(k)));
self.forward.gc().await
}
}
impl AsRef<RpcContinuations> for TunnelContext {
fn as_ref(&self) -> &RpcContinuations {
&self.rpc_continuations
}
}
impl AsRef<OpenAuthedContinuations<Option<InternedString>>> for TunnelContext {
fn as_ref(&self) -> &OpenAuthedContinuations<Option<InternedString>> {
&self.open_authed_continuations
}
}
impl Context for TunnelContext {}
impl Deref for TunnelContext {
type Target = TunnelContextSeed;
fn deref(&self) -> &Self::Target {
&*self.0
}
}
#[derive(Debug, Deserialize, Serialize, Parser)]
pub struct TunnelAddrParams {
pub tunnel: IpAddr,
}
impl CallRemote<TunnelContext> for CliContext {
async fn call_remote(
&self,
mut method: &str,
params: Value,
_: Empty,
) -> Result<Value, RpcError> {
let (tunnel_addr, addr_from_config) = if let Some(addr) = self.tunnel_addr {
(addr, true)
} else if let Some(addr) = self.tunnel_listen {
(addr, true)
} else {
(TUNNEL_DEFAULT_LISTEN, false)
};
let local =
if let Ok(local) = read_file_to_string(TunnelContext::LOCAL_AUTH_COOKIE_PATH).await {
self.cookie_store
.lock()
.unwrap()
.insert_raw(
&Cookie::build(("local", local))
.domain(&tunnel_addr.ip().to_string())
.expires(Expiration::Session)
.same_site(SameSite::Strict)
.build(),
&format!("http://{tunnel_addr}").parse()?,
)
.with_kind(crate::ErrorKind::Network)?;
true
} else {
false
};
let (url, sig_ctx) = if local && tunnel_addr.ip().is_loopback() {
(format!("http://{tunnel_addr}/rpc/v0").parse()?, None)
} else if addr_from_config {
(
format!("https://{tunnel_addr}/rpc/v0").parse()?,
Some(InternedString::from_display(&tunnel_addr.ip())),
)
} else {
return Err(Error::new(eyre!("`--tunnel` required"), ErrorKind::InvalidRequest).into());
};
method = method.strip_prefix("tunnel.").unwrap_or(method);
crate::middleware::auth::signature::call_remote(
self,
url,
HeaderMap::new(),
sig_ctx.as_deref(),
method,
params,
)
.await
}
}
#[derive(Debug, Deserialize, Serialize, Parser)]
pub struct TunnelUrlParams {
pub tunnel: Url,
}
impl CallRemote<TunnelContext, TunnelUrlParams> for RpcContext {
async fn call_remote(
&self,
mut method: &str,
params: Value,
TunnelUrlParams { tunnel }: TunnelUrlParams,
) -> Result<Value, RpcError> {
let url = tunnel.join("rpc/v0")?;
method = method.strip_prefix("tunnel.").unwrap_or(method);
let sig_ctx = url.host_str().map(InternedString::from_display);
crate::middleware::auth::signature::call_remote(
self,
url,
HeaderMap::new(),
sig_ctx.as_deref(),
method,
params,
)
.await
}
}
pub static TUNNEL_UI_CELL: OnceLock<Dir<'static>> = OnceLock::new();
impl UiContext for TunnelContext {
fn ui_dir() -> &'static Dir<'static> {
TUNNEL_UI_CELL.get().unwrap_or(&EMPTY_DIR)
}
fn api() -> ParentHandler<Self> {
tracing::info!("loading tunnel api...");
tunnel_api()
}
fn middleware(server: rpc_toolkit::Server<Self>) -> rpc_toolkit::HttpServer<Self> {
server.middleware(Cors::new()).middleware(
Auth::new()
.with_local_auth()
.with_signature_auth()
.with_session_auth(),
)
}
}

325
core/src/tunnel/db.rs Normal file
View File

@@ -0,0 +1,325 @@
use std::collections::{BTreeMap, BTreeSet};
use std::net::SocketAddrV4;
use std::path::PathBuf;
use std::time::Duration;
use axum::extract::ws;
use clap::Parser;
use imbl::{HashMap, OrdMap};
use imbl_value::InternedString;
use itertools::Itertools;
use patch_db::Dump;
use patch_db::json_ptr::{JsonPointer, ROOT};
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use tracing::instrument;
use ts_rs::TS;
use crate::GatewayId;
use crate::auth::Sessions;
use crate::context::CliContext;
use crate::db::model::public::NetworkInterfaceInfo;
use crate::prelude::*;
use crate::rpc_continuations::{Guid, RpcContinuation};
use crate::sign::AnyVerifyingKey;
use crate::tunnel::auth::SignerInfo;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::web::WebserverInfo;
use crate::tunnel::wg::WgServer;
use crate::util::serde::{HandlerExtSerde, apply_expr};
#[derive(Default, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
pub struct TunnelDatabase {
pub webserver: WebserverInfo,
pub sessions: Sessions,
pub password: Option<String>,
#[ts(as = "std::collections::HashMap::<AnyVerifyingKey, SignerInfo>")]
pub auth_pubkeys: HashMap<AnyVerifyingKey, SignerInfo>,
#[ts(as = "std::collections::BTreeMap::<AnyVerifyingKey, SignerInfo>")]
pub gateways: OrdMap<GatewayId, NetworkInterfaceInfo>,
pub wg: WgServer,
pub port_forwards: PortForwards,
}
impl Model<TunnelDatabase> {
pub fn gc_forwards(&mut self) -> Result<BTreeSet<SocketAddrV4>, Error> {
let mut keep_sources = BTreeSet::new();
let mut keep_targets = BTreeSet::new();
for (_, cfg) in self.as_wg().as_subnets().as_entries()? {
keep_targets.extend(cfg.as_clients().keys()?);
}
self.as_port_forwards_mut().mutate(|pf| {
Ok(pf.0.retain(|k, v| {
if keep_targets.contains(v.ip()) {
keep_sources.insert(*k);
true
} else {
false
}
}))
})?;
Ok(keep_sources)
}
}
#[test]
fn export_bindings_tunnel_db() {
TunnelDatabase::export_all_to("bindings/tunnel").unwrap();
}
#[derive(Clone, Debug, Default, Deserialize, Serialize, TS)]
pub struct PortForwards(pub BTreeMap<SocketAddrV4, SocketAddrV4>);
impl Map for PortForwards {
type Key = SocketAddrV4;
type Value = SocketAddrV4;
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
Self::key_string(key)
}
fn key_string(key: &Self::Key) -> Result<InternedString, Error> {
Ok(InternedString::from_display(key))
}
}
pub fn db_api<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
.subcommand(
"dump",
from_fn_async(cli_dump)
.with_display_serializable()
.with_about("Filter/query db to display tables and records"),
)
.subcommand(
"dump",
from_fn_async(dump)
.with_metadata("admin", Value::Bool(true))
.no_cli(),
)
.subcommand(
"subscribe",
from_fn_async(subscribe)
.with_metadata("get_session", Value::Bool(true))
.no_cli(),
)
.subcommand(
"apply",
from_fn_async(cli_apply)
.no_display()
.with_about("Update a db record"),
)
.subcommand(
"apply",
from_fn_async(apply)
.with_metadata("admin", Value::Bool(true))
.no_cli(),
)
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
#[command(rename_all = "kebab-case")]
pub struct CliDumpParams {
#[arg(long = "pointer", short = 'p')]
pointer: Option<JsonPointer>,
path: Option<PathBuf>,
}
#[instrument(skip_all)]
async fn cli_dump(
HandlerArgs {
context,
parent_method,
method,
params: CliDumpParams { pointer, path },
..
}: HandlerArgs<CliContext, CliDumpParams>,
) -> Result<Dump, RpcError> {
let dump = if let Some(path) = path {
PatchDb::open(path).await?.dump(&ROOT).await
} else {
let method = parent_method.into_iter().chain(method).join(".");
from_value::<Dump>(
context
.call_remote::<TunnelContext>(&method, imbl_value::json!({ "pointer": pointer }))
.await?,
)?
};
Ok(dump)
}
#[derive(Deserialize, Serialize, Parser, TS)]
#[serde(rename_all = "camelCase")]
#[command(rename_all = "kebab-case")]
pub struct DumpParams {
#[arg(long = "pointer", short = 'p')]
#[ts(type = "string | null")]
pointer: Option<JsonPointer>,
}
pub async fn dump(ctx: TunnelContext, DumpParams { pointer }: DumpParams) -> Result<Dump, Error> {
Ok(ctx
.db
.dump(&pointer.as_ref().map_or(ROOT, |p| p.borrowed()))
.await)
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
#[command(rename_all = "kebab-case")]
pub struct CliApplyParams {
expr: String,
path: Option<PathBuf>,
}
#[instrument(skip_all)]
async fn cli_apply(
HandlerArgs {
context,
parent_method,
method,
params: CliApplyParams { expr, path },
..
}: HandlerArgs<CliContext, CliApplyParams>,
) -> Result<(), RpcError> {
if let Some(path) = path {
PatchDb::open(path)
.await?
.apply_function(|db| {
let res = apply_expr(
serde_json::to_value(patch_db::Value::from(db))
.with_kind(ErrorKind::Deserialization)?
.into(),
&expr,
)?;
Ok::<_, Error>((
to_value(
&serde_json::from_value::<TunnelDatabase>(res.clone().into()).with_ctx(
|_| {
(
crate::ErrorKind::Deserialization,
"result does not match database model",
)
},
)?,
)?,
(),
))
})
.await
.result?;
} else {
let method = parent_method.into_iter().chain(method).join(".");
context
.call_remote::<TunnelContext>(&method, imbl_value::json!({ "expr": expr }))
.await?;
}
Ok(())
}
#[derive(Deserialize, Serialize, Parser, TS)]
#[serde(rename_all = "camelCase")]
#[command(rename_all = "kebab-case")]
pub struct ApplyParams {
expr: String,
path: Option<PathBuf>,
}
pub async fn apply(ctx: TunnelContext, ApplyParams { expr, .. }: ApplyParams) -> Result<(), Error> {
ctx.db
.mutate(|db| {
let res = apply_expr(
serde_json::to_value(patch_db::Value::from(db.clone()))
.with_kind(ErrorKind::Deserialization)?
.into(),
&expr,
)?;
db.ser(
&serde_json::from_value::<TunnelDatabase>(res.clone().into()).with_ctx(|_| {
(
crate::ErrorKind::Deserialization,
"result does not match database model",
)
})?,
)
})
.await
.result
}
#[derive(Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct SubscribeParams {
#[ts(type = "string | null")]
pointer: Option<JsonPointer>,
#[ts(skip)]
#[serde(rename = "__Auth_session")]
session: Option<InternedString>,
}
#[derive(Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct SubscribeRes {
#[ts(type = "{ id: number; value: unknown }")]
pub dump: Dump,
pub guid: Guid,
}
pub async fn subscribe(
ctx: TunnelContext,
SubscribeParams { pointer, session }: SubscribeParams,
) -> Result<SubscribeRes, Error> {
let (dump, mut sub) = ctx
.db
.dump_and_sub(pointer.unwrap_or_else(|| ROOT.to_owned()))
.await;
let guid = Guid::new();
ctx.rpc_continuations
.add(
guid.clone(),
RpcContinuation::ws_authed(
&ctx,
session,
|mut ws| async move {
if let Err(e) = async {
loop {
tokio::select! {
rev = sub.recv() => {
if let Some(rev) = rev {
ws.send(ws::Message::Text(
serde_json::to_string(&rev)
.with_kind(ErrorKind::Serialization)?
.into(),
))
.await
.with_kind(ErrorKind::Network)?;
} else {
return ws.normal_close("complete").await;
}
}
msg = ws.recv() => {
if msg.transpose().with_kind(ErrorKind::Network)?.is_none() {
return Ok(())
}
}
}
}
}
.await
{
tracing::error!("Error in db websocket: {e}");
tracing::debug!("{e:?}");
}
},
Duration::from_secs(30),
),
)
.await;
Ok(SubscribeRes { dump, guid })
}

23
core/src/tunnel/mod.rs Normal file
View File

@@ -0,0 +1,23 @@
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use axum::Router;
use crate::net::static_server::ui_router;
use crate::tunnel::context::TunnelContext;
pub mod api;
pub mod auth;
pub mod context;
pub mod db;
pub mod web;
pub mod wg;
pub const TUNNEL_DEFAULT_PORT: u16 = 5960;
pub const TUNNEL_DEFAULT_LISTEN: SocketAddr = SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::new(127, 0, 59, 60),
TUNNEL_DEFAULT_PORT,
));
pub fn tunnel_router(ctx: TunnelContext) -> Router {
ui_router(ctx)
}

View File

@@ -0,0 +1,4 @@
[Peer]
PublicKey = {pubkey}
PresharedKey = {psk}
AllowedIPs = {addr}/32

View File

@@ -0,0 +1,5 @@
[Interface]
Address = {subnets}
PrivateKey = {server_privkey}
ListenPort = {server_port}

683
core/src/tunnel/web.rs Normal file
View File

@@ -0,0 +1,683 @@
use std::collections::VecDeque;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use clap::Parser;
use imbl_value::{InternedString, json};
use itertools::Itertools;
use openssl::pkey::{PKey, Private};
use openssl::x509::X509;
use rpc_toolkit::{
Context, Empty, HandlerArgs, HandlerExt, ParentHandler, from_fn_async, from_fn_async_local,
};
use serde::{Deserialize, Serialize};
use tokio_rustls::rustls::ServerConfig;
use tokio_rustls::rustls::crypto::CryptoProvider;
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio_rustls::rustls::server::ClientHello;
use ts_rs::TS;
use crate::context::CliContext;
use crate::hostname::Hostname;
use crate::net::ssl::{SANInfo, root_ca_start_time};
use crate::net::tls::TlsHandler;
use crate::net::web_server::Accept;
use crate::prelude::*;
use crate::tunnel::auth::SetPasswordParams;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::db::TunnelDatabase;
use crate::util::serde::{HandlerExtSerde, Pem, display_serializable};
use crate::util::tui::{choose, parse_as, prompt, prompt_multiline};
#[derive(Debug, Default, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
pub struct WebserverInfo {
pub enabled: bool,
pub listen: Option<SocketAddr>,
pub certificate: Option<TunnelCertData>,
}
#[derive(Debug, Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct TunnelCertData {
pub key: Pem<PKey<Private>>,
pub cert: Pem<Vec<X509>>,
}
#[derive(Clone)]
pub struct TunnelCertHandler {
pub db: TypedPatchDb<TunnelDatabase>,
pub crypto_provider: Arc<CryptoProvider>,
}
impl<'a, A> TlsHandler<'a, A> for TunnelCertHandler
where
A: Accept + 'a,
<A as Accept>::Metadata: Send + Sync,
{
async fn get_config(
&'a mut self,
_: &'a ClientHello<'a>,
_: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig> {
let cert_info = self
.db
.peek()
.await
.as_webserver()
.as_certificate()
.de()
.log_err()??;
let cert_chain: Vec<_> = cert_info
.cert
.0
.iter()
.map(|c| Ok::<_, Error>(CertificateDer::from(c.to_der()?)))
.collect::<Result<_, _>>()
.log_err()?;
let cert_key = cert_info.key.0.private_key_to_pkcs8().log_err()?;
let mut cfg = ServerConfig::builder_with_provider(self.crypto_provider.clone())
.with_safe_default_protocol_versions()
.log_err()?
.with_no_client_auth()
.with_single_cert(
cert_chain,
PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert_key)),
)
.log_err()?;
cfg.alpn_protocols
.extend([b"http/1.1".into(), b"h2".into()]);
Some(cfg)
}
}
pub fn web_api<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
.subcommand(
"init",
from_fn_async_local(init_web)
.no_display()
.with_about("Initialize the webserver"),
)
.subcommand(
"set-listen",
from_fn_async(set_listen)
.no_display()
.with_about("Set the listen address for the webserver")
.with_call_remote::<CliContext>(),
)
.subcommand(
"get-listen",
from_fn_async(get_listen)
.with_display_serializable()
.with_about("Get the listen address for the webserver")
.with_call_remote::<CliContext>(),
)
.subcommand(
"get-available-ips",
from_fn_async(get_available_ips)
.with_display_serializable()
.with_about("Get available IP addresses to bind to")
.with_call_remote::<CliContext>(),
)
.subcommand(
"import-certificate",
from_fn_async(import_certificate_rpc).no_cli(),
)
.subcommand(
"import-certificate",
from_fn_async_local(import_certificate_cli)
.no_display()
.with_about("Import a certificate to use for the webserver"),
)
.subcommand(
"generate-certificate",
from_fn_async(generate_certificate)
.with_about("Generate a certificate to use for the webserver")
.with_call_remote::<CliContext>(),
)
.subcommand(
"get-certificate",
from_fn_async(get_certificate)
.with_display_serializable()
.with_custom_display_fn(|HandlerArgs { params, .. }, res| {
if let Some(format) = params.format {
return display_serializable(format, res);
}
if let Some(res) = res {
println!("{res}");
}
Ok(())
})
.with_about("Get the certificate for the webserver")
.with_call_remote::<CliContext>(),
)
.subcommand(
"enable",
from_fn_async(enable_web)
.with_about("Enable the webserver")
.no_display()
.with_call_remote::<CliContext>(),
)
.subcommand(
"disable",
from_fn_async(disable_web)
.no_display()
.with_about("Disable the webserver")
.with_call_remote::<CliContext>(),
)
.subcommand(
"reset",
from_fn_async(reset_web)
.no_display()
.with_about("Reset the webserver")
.with_call_remote::<CliContext>(),
)
}
pub async fn import_certificate_rpc(
ctx: TunnelContext,
cert_data: TunnelCertData,
) -> Result<(), Error> {
ctx.db
.mutate(|db| {
db.as_webserver_mut()
.as_certificate_mut()
.ser(&Some(cert_data))
})
.await
.result?;
Ok(())
}
pub async fn import_certificate_cli(
HandlerArgs {
context,
parent_method,
method,
..
}: HandlerArgs<CliContext>,
) -> Result<(), Error> {
let mut key_string = String::new();
let key: Pem<PKey<Private>> =
prompt_multiline("Please paste in your PEM encoded private key: ", |line| {
key_string.push_str(&line);
key_string.push_str("\n");
if line.trim().starts_with("-----END") {
return key_string.parse().map(Some).map_err(|e| {
key_string.truncate(0);
e
});
}
Ok(None)
})
.await?;
let mut chain = Vec::<X509>::new();
let mut cert_string = String::new();
prompt_multiline(
concat!(
"Please paste in your PEM encoded certificate",
" (or certificate chain):"
),
|line| {
cert_string.push_str(&line);
cert_string.push_str("\n");
if line.trim().starts_with("-----END") {
let cert = cert_string.parse::<Pem<X509>>();
cert_string.truncate(0);
let cert = cert?;
let pubkey = cert.0.public_key()?;
if chain.is_empty() {
if !key.public_eq(&pubkey) {
return Err(Error::new(
eyre!("Certificate does not match key!"),
ErrorKind::InvalidSignature,
));
}
}
if let Some(prev) = chain.last() {
if !prev.verify(&pubkey)? {
return Err(Error::new(
eyre!(concat!(
"Invalid Fullchain: ",
"Previous cert was not signed by this certificate's key"
)),
ErrorKind::InvalidSignature,
));
}
}
let is_root = cert.0.verify(&pubkey)?;
chain.push(cert.0);
if is_root { Ok(Some(())) } else { Ok(None) }
} else {
Ok(None)
}
},
)
.await?;
context
.call_remote::<TunnelContext>(
&parent_method.iter().chain(method.iter()).join("."),
to_value(&TunnelCertData {
key,
cert: Pem(chain),
})?,
)
.await?;
Ok(())
}
#[derive(Debug, Deserialize, Serialize, Parser)]
pub struct GenerateCertParams {
#[arg(help = "Subject Alternative Name(s)")]
pub subject: Vec<InternedString>,
}
pub async fn generate_certificate(
ctx: TunnelContext,
GenerateCertParams { subject }: GenerateCertParams,
) -> Result<Pem<Vec<X509>>, Error> {
let saninfo = SANInfo::new(&subject.into_iter().collect());
let root_key = crate::net::ssl::gen_nistp256()?;
let root_cert = crate::net::ssl::make_root_cert(
&root_key,
&Hostname("start-tunnel".into()),
root_ca_start_time().await,
)?;
let int_key = crate::net::ssl::gen_nistp256()?;
let int_cert = crate::net::ssl::make_int_cert((&root_key, &root_cert), &int_key)?;
let key = crate::net::ssl::gen_nistp256()?;
let cert = crate::net::ssl::make_leaf_cert((&int_key, &int_cert), (&key, &saninfo))?;
let chain = Pem(vec![cert, int_cert, root_cert]);
ctx.db
.mutate(|db| {
db.as_webserver_mut()
.as_certificate_mut()
.ser(&Some(TunnelCertData {
key: Pem(key),
cert: chain.clone(),
}))
})
.await
.result?;
Ok(chain)
}
pub async fn get_certificate(ctx: TunnelContext) -> Result<Option<Pem<Vec<X509>>>, Error> {
ctx.db
.peek()
.await
.as_webserver()
.as_certificate()
.de()?
.map(|cert_data| Ok(cert_data.cert))
.transpose()
}
#[derive(Debug, Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct SetListenParams {
pub listen: SocketAddr,
}
pub async fn set_listen(
ctx: TunnelContext,
SetListenParams { listen }: SetListenParams,
) -> Result<(), Error> {
// Validate that the address is available to bind
tokio::net::TcpListener::bind(listen)
.await
.with_kind(ErrorKind::Network)
.with_ctx(|_| {
(
ErrorKind::Network,
format!("{} is not available to bind to", listen),
)
})?;
ctx.db
.mutate(|db| {
db.as_webserver_mut().as_listen_mut().ser(&Some(listen))?;
Ok(())
})
.await
.result
}
pub async fn get_listen(ctx: TunnelContext) -> Result<Option<SocketAddr>, Error> {
ctx.db.peek().await.as_webserver().as_listen().de()
}
pub async fn get_available_ips(ctx: TunnelContext) -> Result<Vec<IpAddr>, Error> {
let ips = ctx.net_iface.peek(|interfaces| {
interfaces
.values()
.flat_map(|info| {
info.ip_info
.iter()
.flat_map(|ip_info| ip_info.subnets.iter().map(|subnet| subnet.addr()))
})
.collect::<Vec<IpAddr>>()
});
Ok(ips)
}
pub async fn enable_web(ctx: TunnelContext) -> Result<(), Error> {
ctx.db
.mutate(|db| {
if db.as_webserver().as_listen().transpose_ref().is_none() {
return Err(Error::new(
eyre!("Listen is not set"),
ErrorKind::ParseNetAddress,
));
}
if db.as_webserver().as_certificate().transpose_ref().is_none() {
return Err(Error::new(
eyre!("Certificate is not set"),
ErrorKind::OpenSsl,
));
}
if db.as_password().transpose_ref().is_none() {
return Err(Error::new(
eyre!("Password is not set"),
ErrorKind::Authorization,
));
};
db.as_webserver_mut().as_enabled_mut().ser(&true)
})
.await
.result
}
pub async fn disable_web(ctx: TunnelContext) -> Result<(), Error> {
ctx.db
.mutate(|db| db.as_webserver_mut().as_enabled_mut().ser(&false))
.await
.result
}
pub async fn reset_web(ctx: TunnelContext) -> Result<(), Error> {
ctx.db
.mutate(|db| {
db.as_webserver_mut().as_enabled_mut().ser(&false)?;
db.as_webserver_mut().as_listen_mut().ser(&None)?;
db.as_webserver_mut().as_certificate_mut().ser(&None)?;
db.as_password_mut().ser(&None)?;
Ok(())
})
.await
.result
}
fn is_valid_domain(domain: &str) -> bool {
if domain.is_empty() || domain.len() > 253 || domain.starts_with('.') || domain.ends_with('.') {
return false;
}
let labels: Vec<&str> = domain.split('.').collect();
for label in labels {
if label.is_empty() || label.len() > 63 {
return false;
}
if !label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-') {
return false;
}
if label.chars().next().map_or(true, |c| c == '-')
|| label.chars().next_back().map_or(true, |c| c == '-')
{
return false;
}
}
true
}
pub async fn init_web(ctx: CliContext) -> Result<(), Error> {
let mut password = None;
loop {
match ctx
.call_remote::<TunnelContext>("web.enable", json!({}))
.await
{
Ok(_) => {
let listen = from_value::<SocketAddr>(
ctx.call_remote::<TunnelContext>("web.get-listen", json!({}))
.await?,
)?;
println!("✅ Success! ✅");
println!(
"The webserver is running. Below is your URL{} and Root Certificate Authority (Root CA).",
if password.is_some() {
", password,"
} else {
""
}
);
println!();
println!("🌐 URL");
println!("https://{listen}");
if listen.ip().is_unspecified() {
println!(concat!(
"Note: this is the unspecified address. ",
"This means you can use any IP address available to this device to connect. ",
"Using the above address as-is will only work from this device."
));
} else if listen.ip().is_loopback() {
println!(concat!(
"Note: this is a loopback address. ",
"This is only recommended if you are planning to run a proxy in front of the web ui. ",
"Using the above address as-is will only work from this device."
));
}
println!();
if let Some(password) = password {
println!("🔒 Password");
println!("{password}");
println!();
println!(concat!(
"If you lose or forget your password, you can reset it using the following command: ",
"start-tunnel auth reset-password"
));
} else {
println!(concat!(
"Your password was set up previously. ",
"If you don't remember it, you can reset it using the command: ",
"start-tunnel auth reset-password"
));
}
println!();
let cert = from_value::<Pem<Vec<X509>>>(
ctx.call_remote::<TunnelContext>("web.get-certificate", json!({}))
.await?,
)?
.0
.pop()
.map(Pem)
.or_not_found("certificate in chain")?;
println!("📝 Root CA:");
print!("{cert}");
println!(concat!(
"To trust your StartTunnel Root CA (above):\n",
" 1. Copy the Root CA ",
"(starting with -----BEGIN CERTIFICATE----- and ending with -----END CERTIFICATE-----).\n",
" 2. Open a text editor: \n",
" - Linux: gedit, nano, or any editor\n",
" - Mac: TextEdit\n",
" - Windows: Notepad\n",
" 3. Paste the contents of your Root CA.\n",
" 4. Save the file with a `.crt` extension ",
"(e.g. `start-tunnel.crt`) (make sure it saves as plain text, not rich text).\n",
" 5. Follow instructions to trust you StartTunnel Root CA: ",
"https://staging.docs.start9.com/user-manual/trust-ca.html#2-trust-your-servers-root-ca."
));
return Ok(());
}
Err(e) if e.kind == ErrorKind::ParseNetAddress => {
println!("Select the IP address at which to host the web interface:");
let mut suggested_addrs = from_value::<Vec<IpAddr>>(
ctx.call_remote::<TunnelContext>("web.get-available-ips", json!({}))
.await?,
)?;
suggested_addrs.retain(|ip| match ip {
IpAddr::V4(a) => !a.is_loopback() && !a.is_private(),
IpAddr::V6(a) => !a.is_loopback() && !a.is_unicast_link_local(),
});
let ip = if suggested_addrs.len() == 1 {
suggested_addrs[0]
} else if suggested_addrs.is_empty() {
prompt("Listen Address: ", parse_as::<IpAddr>("IP Address"), None).await?
} else if suggested_addrs.len() > 16 {
prompt(
&format!("Listen Address [{}]: ", suggested_addrs[0]),
parse_as::<IpAddr>("IP Address"),
Some(suggested_addrs[0]),
)
.await?
} else {
*choose("Listen Address:", &suggested_addrs).await?
};
println!(concat!(
"Enter the port at which to host the web interface. ",
"The recommended default is 8443. ",
"If you change the default, choose an uncommon port to avoid conflicts: "
));
let port = prompt("Port [8443]: ", parse_as::<u16>("port"), Some(8443)).await?;
let listen = SocketAddr::new(ip, port);
ctx.call_remote::<TunnelContext>(
"web.set-listen",
to_value(&SetListenParams { listen })?,
)
.await?;
println!();
}
Err(e) if e.kind == ErrorKind::OpenSsl => {
enum Choice {
Generate,
Provide,
}
impl std::fmt::Display for Choice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Generate => write!(f, "Generate"),
Self::Provide => write!(f, "Provide"),
}
}
}
let options = vec![Choice::Generate, Choice::Provide];
let choice = choose(
concat!(
"Select whether to generate an SSL certificate ",
"or provide your own certificate (and key):"
),
&options,
)
.await?;
match choice {
Choice::Generate => {
let listen = from_value::<Option<SocketAddr>>(
ctx.call_remote::<TunnelContext>("web.get-listen", json!({}))
.await?,
)?
.filter(|a| !a.ip().is_unspecified());
let san_info = if let Some(listen) = listen {
vec![InternedString::from_display(&listen.ip())]
} else {
println!(
"List all IP addresses and domains for which to sign the certificate, separated by commas."
);
prompt(
"Subject Alternative Name(s): ",
|s| {
s.split(",")
.map(|s| {
let s = s.trim();
if let Ok(ip) = s.parse::<IpAddr>() {
Ok(InternedString::from_display(&ip))
} else if is_valid_domain(s) {
Ok(s.into())
} else {
Err(format!(
"{s} is not a valid ip address or domain"
))
}
})
.collect()
},
listen.map(|l| vec![InternedString::from_display(&l.ip())]),
)
.await?
};
ctx.call_remote::<TunnelContext>(
"web.generate-certificate",
to_value(&GenerateCertParams { subject: san_info })?,
)
.await?;
}
Choice::Provide => {
import_certificate_cli(HandlerArgs {
context: ctx.clone(),
parent_method: vec!["web", "import-certificate"].into(),
method: VecDeque::new(),
params: Empty {},
inherited_params: Empty {},
raw_params: json!({}),
})
.await?;
}
}
println!();
}
Err(e) if e.kind == ErrorKind::Authorization => {
println!("Generating a random password...");
let params = SetPasswordParams {
password: base32::encode(
base32::Alphabet::Rfc4648Lower { padding: false },
&rand::random::<[u8; 16]>(),
),
};
ctx.call_remote::<TunnelContext>("auth.set-password", to_value(&params)?)
.await?;
password = Some(params.password);
println!();
}
Err(e) => return Err(e.into()),
}
}
}

255
core/src/tunnel/wg.rs Normal file
View File

@@ -0,0 +1,255 @@
use std::collections::BTreeMap;
use std::net::{Ipv4Addr, SocketAddr};
use imbl_value::InternedString;
use ipnet::Ipv4Net;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use ts_rs::TS;
use x25519_dalek::{PublicKey, StaticSecret};
use crate::prelude::*;
use crate::util::Invoke;
use crate::util::io::write_file_atomic;
use crate::util::serde::Base64;
pub const WIREGUARD_INTERFACE_NAME: &str = "wg-start-tunnel";
#[derive(Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
pub struct WgServer {
pub port: u16,
pub key: Base64<WgKey>,
pub subnets: WgSubnetMap,
}
impl Default for WgServer {
fn default() -> Self {
Self {
port: 51820,
key: Base64(WgKey::generate()),
subnets: WgSubnetMap::default(),
}
}
}
impl WgServer {
pub fn server_config<'a>(&'a self) -> ServerConfig<'a> {
ServerConfig(self)
}
pub async fn sync(&self) -> Result<(), Error> {
Command::new("wg-quick")
.arg("down")
.arg(WIREGUARD_INTERFACE_NAME)
.invoke(ErrorKind::Network)
.await
.or_else(|e| {
let msg = e.source.to_string();
if msg.contains("does not exist") || msg.contains("is not a WireGuard interface") {
Ok(Vec::new())
} else {
Err(e)
}
})?;
write_file_atomic(
const_format::formatcp!("/etc/wireguard/{WIREGUARD_INTERFACE_NAME}.conf"),
self.server_config().to_string().as_bytes(),
)
.await?;
Command::new("wg-quick")
.arg("up")
.arg(WIREGUARD_INTERFACE_NAME)
.invoke(ErrorKind::Network)
.await?;
Ok(())
}
}
#[derive(Default, Deserialize, Serialize, TS)]
pub struct WgSubnetMap(
#[ts(as = "BTreeMap::<String, WgSubnetConfig>")] pub BTreeMap<Ipv4Net, WgSubnetConfig>,
);
impl Map for WgSubnetMap {
type Key = Ipv4Net;
type Value = WgSubnetConfig;
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
Self::key_string(key)
}
fn key_string(key: &Self::Key) -> Result<InternedString, Error> {
Ok(InternedString::from_display(key))
}
}
#[derive(Default, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
pub struct WgSubnetConfig {
pub name: InternedString,
pub clients: WgSubnetClients,
}
impl WgSubnetConfig {
pub fn new(name: InternedString) -> Self {
Self {
name,
..Self::default()
}
}
}
#[derive(Default, Deserialize, Serialize, TS)]
pub struct WgSubnetClients(pub BTreeMap<Ipv4Addr, WgConfig>);
impl Map for WgSubnetClients {
type Key = Ipv4Addr;
type Value = WgConfig;
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
Self::key_string(key)
}
fn key_string(key: &Self::Key) -> Result<InternedString, Error> {
Ok(InternedString::from_display(key))
}
}
#[derive(Clone)]
pub struct WgKey(StaticSecret);
impl WgKey {
pub fn generate() -> Self {
Self(StaticSecret::random_from_rng(
ssh_key::rand_core::OsRng::default(),
))
}
}
impl AsRef<[u8]> for WgKey {
fn as_ref(&self) -> &[u8] {
self.0.as_bytes()
}
}
impl TryFrom<Vec<u8>> for WgKey {
type Error = Error;
fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
Ok(Self(
<[u8; 32]>::try_from(value)
.map_err(|_| Error::new(eyre!("invalid key length"), ErrorKind::Deserialization))?
.into(),
))
}
}
impl std::ops::Deref for WgKey {
type Target = StaticSecret;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Base64<WgKey> {
pub fn verifying_key(&self) -> Base64<PublicKey> {
Base64((&*self.0).into())
}
}
#[derive(Clone, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
pub struct WgConfig {
pub name: InternedString,
pub key: Base64<WgKey>,
pub psk: Base64<[u8; 32]>,
}
impl WgConfig {
pub fn generate(name: InternedString) -> Self {
Self {
name,
key: Base64(WgKey::generate()),
psk: Base64(rand::random()),
}
}
pub fn server_peer_config<'a>(&'a self, addr: Ipv4Addr) -> ServerPeerConfig<'a> {
ServerPeerConfig {
client_config: self,
client_addr: addr,
}
}
pub fn client_config(
self,
addr: Ipv4Addr,
subnet: Ipv4Net,
server_pubkey: Base64<PublicKey>,
server_addr: SocketAddr,
) -> ClientConfig {
ClientConfig {
client_config: self,
client_addr: addr,
subnet,
server_pubkey,
server_addr,
}
}
}
pub struct ServerPeerConfig<'a> {
client_config: &'a WgConfig,
client_addr: Ipv4Addr,
}
impl<'a> std::fmt::Display for ServerPeerConfig<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
include_str!("./server-peer.conf.template"),
pubkey = self.client_config.key.verifying_key().to_padded_string(),
psk = self.client_config.psk.to_padded_string(),
addr = self.client_addr,
)
}
}
fn deserialize_verifying_key<'de, D>(deserializer: D) -> Result<Base64<PublicKey>, D::Error>
where
D: serde::Deserializer<'de>,
{
Base64::<Vec<u8>>::deserialize(deserializer).and_then(|b| {
Ok(Base64(PublicKey::from(<[u8; 32]>::try_from(b.0).map_err(
|e: Vec<u8>| serde::de::Error::invalid_length(e.len(), &"a 32 byte base64 string"),
)?)))
})
}
#[derive(Clone, Serialize, Deserialize)]
pub struct ClientConfig {
client_config: WgConfig,
client_addr: Ipv4Addr,
subnet: Ipv4Net,
#[serde(deserialize_with = "deserialize_verifying_key")]
server_pubkey: Base64<PublicKey>,
server_addr: SocketAddr,
}
impl std::fmt::Display for ClientConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
include_str!("./client.conf.template"),
name = self.client_config.name,
privkey = self.client_config.key.to_padded_string(),
psk = self.client_config.psk.to_padded_string(),
addr = Ipv4Net::new_assert(self.client_addr, self.subnet.prefix_len()),
subnet = self.subnet.trunc(),
server_pubkey = self.server_pubkey.to_padded_string(),
server_addr = self.server_addr,
)
}
}
pub struct ServerConfig<'a>(&'a WgServer);
impl<'a> std::fmt::Display for ServerConfig<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self(server) = *self;
write!(
f,
include_str!("./server.conf.template"),
subnets = server.subnets.0.keys().join(", "),
server_port = server.port,
server_privkey = server.key.to_padded_string(),
)?;
for (addr, peer) in server.subnets.0.values().flat_map(|s| &s.clients.0) {
write!(f, "{}", peer.server_peer_config(*addr))?;
}
Ok(())
}
}