Feature/start tunnel (#3037)

* fix live-build resolv.conf

* improved debuggability

* wip: start-tunnel

* fixes for trixie and tor

* non-free-firmware on trixie

* wip

* web server WIP

* wip: tls refactor

* FE patchdb, mocks, and most endpoints

* fix editing records and patch mocks

* refactor complete

* finish api

* build and formatter update

* minor change to viewing addresses and fix build

* fixes

* more providers

* endpoint for getting config

* fix tests

* api fixes

* wip: separate port forward controller into parts

* simplify iptables rules

* bump sdk

* misc fixes

* predict next subnet and ip, use wan ips, and form validation

* refactor: break big components apart and address todos (#3043)

* refactor: break big components apart and address todos

* starttunnel readme, fix pf mocks, fix adding tor domain in startos

---------

Co-authored-by: Matt Hill <mattnine@protonmail.com>

* better tui

* tui tweaks

* fix: address comments

* better regex for subnet

* fixes

* better validation

* handle rpc errors

* build fixes

* fix: address comments (#3044)

* fix: address comments

* fix unread notification mocks

* fix row click for notification

---------

Co-authored-by: Matt Hill <mattnine@protonmail.com>

* fix raspi build

* fix build

* fix build

* fix build

* fix build

* try to fix build

* fix tests

* fix tests

* fix rsync tests

* delete useless effectful test

---------

Co-authored-by: Matt Hill <mattnine@protonmail.com>
Co-authored-by: Alex Inkin <alexander@inkin.ru>
This commit is contained in:
Aiden McClelland
2025-11-07 03:12:05 -07:00
committed by GitHub
parent 1ea525feaa
commit 68f401bfa3
229 changed files with 17255 additions and 10553 deletions

View File

@@ -4,18 +4,18 @@ description = "The core of StartOS"
documentation = "https://docs.rs/start-os"
edition = "2024"
keywords = [
"self-hosted",
"raspberry-pi",
"privacy",
"bitcoin",
"full-node",
"lightning",
"privacy",
"raspberry-pi",
"self-hosted",
]
license = "MIT"
name = "start-os"
readme = "README.md"
repository = "https://github.com/Start9Labs/start-os"
version = "0.4.0-alpha.11" # VERSION_BUMP
license = "MIT"
version = "0.4.0-alpha.12" # VERSION_BUMP
[lib]
name = "startos"
@@ -42,48 +42,61 @@ name = "tunnelbox"
path = "src/main.rs"
[features]
cli = ["cli-startd", "cli-registry", "cli-tunnel"]
arti = [
"arti-client",
"models/arti",
"safelog",
"tor-cell",
"tor-hscrypto",
"tor-hsservice",
"tor-keymgr",
"tor-llcrypto",
"tor-proto",
"tor-rtcompat",
]
cli = ["cli-registry", "cli-startd", "cli-tunnel"]
cli-container = ["procfs", "pty-process"]
cli-registry = []
cli-startd = []
cli-tunnel = []
default = ["cli", "startd", "registry", "cli-container", "tunnel"]
console = ["console-subscriber", "tokio/tracing"]
default = ["cli", "cli-container", "registry", "startd", "tunnel"]
dev = ["backtrace-on-stack-overflow"]
docker = []
registry = []
startd = ["mail-send"]
startd = []
test = []
tunnel = []
console = ["console-subscriber", "tokio/tracing"]
unstable = ["backtrace-on-stack-overflow"]
[dependencies]
aes = { version = "0.7.5", features = ["ctr"] }
arti-client = { version = "0.33", features = [
"compression",
"ephemeral-keystore",
"experimental-api",
"onion-service-client",
"onion-service-service",
"rustls",
"static",
"tokio",
"ephemeral-keystore",
"onion-service-client",
"onion-service-service",
], default-features = false, git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit" }
aes = { version = "0.7.5", features = ["ctr"] }
], default-features = false, git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit", optional = true }
async-acme = { version = "0.6.0", git = "https://github.com/dr-bonez/async-acme.git", features = [
"use_rustls",
"use_tokio",
] }
async-compression = { version = "0.4.4", features = [
"gzip",
async-compression = { version = "0.4.32", features = [
"brotli",
"gzip",
"tokio",
"zstd",
] }
async-stream = "0.3.5"
async-trait = "0.1.74"
aws-lc-sys = { version = "0.32", features = ["bindgen"] }
axum = { version = "0.8.4", features = ["ws"] }
barrage = "0.2.3"
backhand = "0.21.0"
backtrace-on-stack-overflow = { version = "0.3.0", optional = true }
barrage = "0.2.3"
base32 = "0.5.0"
base64 = "0.22.1"
base64ct = "1.6.0"
@@ -98,17 +111,19 @@ console-subscriber = { version = "0.4.1", optional = true }
const_format = "0.2.34"
cookie = "0.18.0"
cookie_store = "0.21.0"
curve25519-dalek = "4.1.3"
der = { version = "0.7.9", features = ["derive", "pem"] }
digest = "0.10.7"
divrem = "1.0.0"
dns-lookup = "2.1.0"
ed25519 = { version = "2.2.3", features = ["pkcs8", "pem", "alloc"] }
ed25519 = { version = "2.2.3", features = ["alloc", "pem", "pkcs8"] }
ed25519-dalek = { version = "2.2.0", features = [
"digest",
"hazmat",
"pkcs8",
"rand_core",
"serde",
"zeroize",
"rand_core",
"digest",
"pkcs8",
] }
ed25519-dalek-v1 = { package = "ed25519-dalek", version = "1" }
exver = { version = "0.2.0", git = "https://github.com/Start9Labs/exver-rs.git", features = [
@@ -125,20 +140,21 @@ hickory-server = "0.25.2"
hmac = "0.12.1"
http = "1.0.0"
http-body-util = "0.1"
hyper = { version = "1.5", features = ["server", "http1", "http2"] }
hyper = { version = "1.5", features = ["http1", "http2", "server"] }
hyper-util = { version = "0.1.10", features = [
"http1",
"http2",
"server",
"server-auto",
"server-graceful",
"service",
"http1",
"http2",
"tokio",
] }
id-pool = { version = "0.2.2", default-features = false, features = [
"serde",
"u16",
] }
iddqd = "0.3.14"
imbl = { version = "6", features = ["serde", "small-chunks"] }
imbl-value = { version = "0.4.3", features = ["ts-rs"] }
include_dir = { version = "0.7.3", features = ["metadata"] }
@@ -156,10 +172,20 @@ jsonpath_lib = { git = "https://github.com/Start9Labs/jsonpath.git" }
lazy_async_pool = "0.3.3"
lazy_format = "2.0"
lazy_static = "1.4.0"
lettre = { version = "0.11.18", default-features = false, features = [
"aws-lc-rs",
"builder",
"hostname",
"pool",
"rustls-platform-verifier",
"smtp-transport",
"tokio1-rustls",
] }
libc = "0.2.149"
log = "0.4.20"
mio = "1"
mbrman = "0.6.0"
miette = { version = "7.6.0", features = ["fancy"] }
mio = "1"
models = { version = "*", path = "../models" }
new_mime_guess = "4"
nix = { version = "0.30.1", features = [
@@ -173,8 +199,8 @@ nix = { version = "0.30.1", features = [
] }
nom = "8.0.0"
num = "0.4.1"
num_enum = "0.7.0"
num_cpus = "1.16.0"
num_enum = "0.7.0"
once_cell = "1.19.0"
openssh-keys = "0.6.2"
openssl = { version = "0.10.57", features = ["vendored"] }
@@ -191,22 +217,22 @@ proptest = "1.3.1"
proptest-derive = "0.5.0"
pty-process = { version = "0.5.1", optional = true }
qrcode = "0.14.1"
r3bl_tui = "0.7.6"
rand = "0.9.2"
regex = "1.10.2"
reqwest = { version = "0.12.4", features = ["stream", "json", "socks"] }
reqwest = { version = "0.12.4", features = ["json", "socks", "stream"] }
reqwest_cookie_store = "0.8.0"
rpassword = "7.2.0"
rpc-toolkit = { git = "https://github.com/Start9Labs/rpc-toolkit.git", branch = "master" }
rust-argon2 = "2.0.0"
rustyline-async = "0.4.1"
safelog = { version = "0.4.8", git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit" }
safelog = { version = "0.4.8", git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit", optional = true }
semver = { version = "1.0.20", features = ["serde"] }
serde = { version = "1.0", features = ["derive", "rc"] }
serde_cbor = { package = "ciborium", version = "0.2.1" }
serde_json = "1.0"
serde_toml = { package = "toml", version = "0.8.2" }
serde_urlencoded = "0.7"
serde_with = { version = "3.4.0", features = ["macros", "json"] }
serde_with = { version = "3.4.0", features = ["json", "macros"] }
serde_yaml = { package = "serde_yml", version = "0.0.12" }
sha-crypt = "0.5.0"
sha2 = "0.10.2"
@@ -214,39 +240,40 @@ shell-words = "1"
signal-hook = "0.3.17"
simple-logging = "2.0.2"
socket2 = { version = "0.6.0", features = ["all"] }
socks5-impl = { version = "0.7.2", features = ["server"] }
socks5-impl = { version = "0.7.2", features = ["client", "server"] }
sqlx = { version = "0.8.6", features = [
"runtime-tokio-rustls",
"postgres",
"runtime-tokio-rustls",
], default-features = false }
sscanf = "0.4.1"
ssh-key = { version = "0.6.2", features = ["ed25519"] }
tar = "0.4.40"
termion = "4.0.5"
thiserror = "2.0.12"
textwrap = "0.16.1"
thiserror = "2.0.12"
tokio = { version = "1.38.1", features = ["full"] }
tokio-rustls = "0.26.0"
tokio-stream = { version = "0.1.14", features = ["io-util", "sync", "net"] }
tokio-stream = { version = "0.1.14", features = ["io-util", "net", "sync"] }
tokio-tar = { git = "https://github.com/dr-bonez/tokio-tar.git" }
tokio-tungstenite = { version = "0.26.2", features = ["native-tls", "url"] }
tokio-util = { version = "0.7.9", features = ["io"] }
tor-cell = { version = "0.33", git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit" }
tor-cell = { version = "0.33", git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit", optional = true }
tor-hscrypto = { version = "0.33", features = [
"full",
], git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit" }
tor-hsservice = { version = "0.33", git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit" }
], git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit", optional = true }
tor-hsservice = { version = "0.33", git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit", optional = true }
tor-keymgr = { version = "0.33", features = [
"ephemeral-keystore",
], git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit" }
], git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit", optional = true }
tor-llcrypto = { version = "0.33", features = [
"full",
], git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit" }
tor-proto = { version = "0.33", git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit" }
], git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit", optional = true }
tor-proto = { version = "0.33", git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit", optional = true }
tor-rtcompat = { version = "0.33", features = [
"tokio",
"rustls",
], git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit" }
"tokio",
], git = "https://github.com/Start9Labs/arti.git", branch = "patch/disable-exit", optional = true }
torut = "0.2.1"
tower-service = "0.3.3"
tracing = "0.1.39"
tracing-error = "0.2.0"
@@ -259,11 +286,10 @@ unix-named-pipe = "0.2.0"
url = { version = "2.4.1", features = ["serde"] }
urlencoding = "2.1.3"
uuid = { version = "1.4.1", features = ["v4"] }
visit-rs = "0.1.1"
x25519-dalek = { version = "2.0.1", features = ["static_secrets"] }
zbus = "5.1.1"
zeroize = "1.6.0"
mail-send = { git = "https://github.com/dr-bonez/mail-send.git", branch = "main", optional = true }
rustls = "0.23.20"
rustls-pki-types = { version = "1.10.1", features = ["alloc"] }
[profile.test]
opt-level = 3

View File

@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::time::SystemTime;
use imbl_value::InternedString;
@@ -107,6 +108,7 @@ impl AccountInfo {
.map(|tor_key| tor_key.onion_address())
.collect(),
)?;
server_info.as_password_hash_mut().ser(&self.password)?;
db.as_private_mut().as_password_mut().ser(&self.password)?;
db.as_private_mut()
.as_ssh_privkey_mut()
@@ -119,12 +121,20 @@ impl AccountInfo {
key_store.as_onion_mut().insert_key(tor_key)?;
}
let cert_store = key_store.as_local_certs_mut();
cert_store
.as_root_key_mut()
.ser(Pem::new_ref(&self.root_ca_key))?;
cert_store
.as_root_cert_mut()
.ser(Pem::new_ref(&self.root_ca_cert))?;
if cert_store.as_root_cert().de()?.0 != self.root_ca_cert {
cert_store
.as_root_key_mut()
.ser(Pem::new_ref(&self.root_ca_key))?;
cert_store
.as_root_cert_mut()
.ser(Pem::new_ref(&self.root_ca_cert))?;
let int_key = crate::net::ssl::generate_key()?;
let int_cert =
crate::net::ssl::make_int_cert((&self.root_ca_key, &self.root_ca_cert), &int_key)?;
cert_store.as_int_key_mut().ser(&Pem(int_key))?;
cert_store.as_int_cert_mut().ser(&Pem(int_cert))?;
cert_store.as_leaves_mut().ser(&BTreeMap::new())?;
}
Ok(())
}

View File

@@ -4,7 +4,7 @@ use clap::{CommandFactory, FromArgMatches, Parser};
pub use models::ActionId;
use models::{PackageId, ReplayId};
use qrcode::QrCode;
use rpc_toolkit::{from_fn_async, Context, HandlerExt, ParentHandler};
use rpc_toolkit::{Context, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use tracing::instrument;
use ts_rs::TS;
@@ -14,7 +14,7 @@ use crate::db::model::package::TaskSeverity;
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::util::serde::{
display_serializable, HandlerExtSerde, StdinDeserializable, WithIoFormat,
HandlerExtSerde, StdinDeserializable, WithIoFormat, display_serializable,
};
pub fn action_api<C: Context>() -> ParentHandler<C> {

View File

@@ -220,7 +220,7 @@ pub fn check_password(hash: &str, password: &str) -> Result<(), Error> {
pub struct LoginParams {
password: String,
#[ts(skip)]
#[serde(rename = "__auth_userAgent")] // from Auth middleware
#[serde(rename = "__Auth_userAgent")] // from Auth middleware
user_agent: Option<String>,
#[serde(default)]
ephemeral: bool,
@@ -279,7 +279,7 @@ pub async fn login_impl<C: AuthContext>(
#[command(rename_all = "kebab-case")]
pub struct LogoutParams {
#[ts(skip)]
#[serde(rename = "__auth_session")] // from Auth middleware
#[serde(rename = "__Auth_session")] // from Auth middleware
session: InternedString,
}
@@ -373,7 +373,7 @@ fn display_sessions(params: WithIoFormat<ListParams>, arg: SessionList) -> Resul
pub struct ListParams {
#[arg(skip)]
#[ts(skip)]
#[serde(rename = "__auth_session")] // from Auth middleware
#[serde(rename = "__Auth_session")] // from Auth middleware
session: Option<InternedString>,
}
@@ -474,30 +474,19 @@ pub async fn reset_password_impl(
let old_password = old_password.unwrap_or_default().decrypt(&ctx)?;
let new_password = new_password.unwrap_or_default().decrypt(&ctx)?;
let mut account = ctx.account.write().await;
if !argon2::verify_encoded(&account.password, old_password.as_bytes())
.with_kind(crate::ErrorKind::IncorrectPassword)?
{
return Err(Error::new(
eyre!("Incorrect Password"),
crate::ErrorKind::IncorrectPassword,
));
}
account.set_password(&new_password)?;
let account_password = &account.password;
let account = account.clone();
ctx.db
.mutate(|d| {
d.as_public_mut()
.as_server_info_mut()
.as_password_hash_mut()
.ser(account_password)?;
account.save(d)?;
Ok(())
})
.await
.result
let account = ctx.account.mutate(|account| {
if !argon2::verify_encoded(&account.password, old_password.as_bytes())
.with_kind(crate::ErrorKind::IncorrectPassword)?
{
return Err(Error::new(
eyre!("Incorrect Password"),
crate::ErrorKind::IncorrectPassword,
));
}
account.set_password(&new_password)?;
Ok(account.clone())
})?;
ctx.db.mutate(|d| account.save(d)).await.result
}
#[instrument(skip_all)]

View File

@@ -317,7 +317,7 @@ async fn perform_backup(
.with_kind(ErrorKind::Filesystem)?;
os_backup_file
.write_all(&IoFormat::Json.to_vec(&OsBackup {
account: ctx.account.read().await.clone(),
account: ctx.account.peek(|a| a.clone()),
ui,
})?)
.await?;
@@ -342,7 +342,7 @@ async fn perform_backup(
let timestamp = Utc::now();
backup_guard.unencrypted_metadata.version = crate::version::Current::default().semver().into();
backup_guard.unencrypted_metadata.hostname = ctx.account.read().await.hostname.clone();
backup_guard.unencrypted_metadata.hostname = ctx.account.peek(|a| a.hostname.clone());
backup_guard.unencrypted_metadata.timestamp = timestamp.clone();
backup_guard.metadata.version = crate::version::Current::default().semver().into();
backup_guard.metadata.timestamp = Some(timestamp);

View File

@@ -11,14 +11,17 @@ use crate::context::config::ClientConfig;
use crate::net::web_server::{Acceptor, WebServer};
use crate::prelude::*;
use crate::registry::context::{RegistryConfig, RegistryContext};
use crate::registry::registry_router;
use crate::util::logger::LOGGER;
#[instrument(skip_all)]
async fn inner_main(config: &RegistryConfig) -> Result<(), Error> {
let server = async {
let ctx = RegistryContext::init(config).await?;
let mut server = WebServer::new(Acceptor::bind([ctx.listen]).await?);
server.serve_registry(ctx.clone());
let server = WebServer::new(
Acceptor::bind([ctx.listen]).await?,
registry_router(ctx.clone()),
);
let mut shutdown_recv = ctx.shutdown.subscribe();

View File

@@ -17,7 +17,7 @@ pub fn main(args: impl IntoIterator<Item = OsString>) {
if let Err(e) = CliApp::new(
|cfg: ClientConfig| Ok(CliContext::init(cfg.load()?)?),
crate::expanded_api(),
crate::main_api(),
)
.run(args)
{

View File

@@ -11,7 +11,8 @@ use crate::disk::fsck::RepairStrategy;
use crate::disk::main::DEFAULT_PASSWORD;
use crate::firmware::{check_for_firmware_update, update_firmware};
use crate::init::{InitPhases, STANDBY_MODE_PATH};
use crate::net::web_server::{UpgradableListener, WebServer};
use crate::net::gateway::UpgradableListener;
use crate::net::web_server::WebServer;
use crate::prelude::*;
use crate::progress::FullProgressTracker;
use crate::shutdown::Shutdown;
@@ -37,7 +38,7 @@ async fn setup_or_init(
let mut update_phase = handle.add_phase("Updating Firmware".into(), Some(10));
let mut reboot_phase = handle.add_phase("Rebooting".into(), Some(1));
server.serve_init(init_ctx);
server.serve_ui_for(init_ctx);
update_phase.start();
if let Err(e) = update_firmware(firmware).await {
@@ -93,7 +94,7 @@ async fn setup_or_init(
let ctx = InstallContext::init().await?;
server.serve_install(ctx.clone());
server.serve_ui_for(ctx.clone());
ctx.shutdown
.subscribe()
@@ -113,7 +114,7 @@ async fn setup_or_init(
{
let ctx = SetupContext::init(server, config)?;
server.serve_setup(ctx.clone());
server.serve_ui_for(ctx.clone());
let mut shutdown = ctx.shutdown.subscribe();
if let Some(shutdown) = shutdown.recv().await.expect("context dropped") {
@@ -149,7 +150,7 @@ async fn setup_or_init(
let init_phases = InitPhases::new(&handle);
let rpc_ctx_phases = InitRpcContextPhases::new(&handle);
server.serve_init(init_ctx);
server.serve_ui_for(init_ctx);
async {
disk_phase.start();
@@ -247,7 +248,7 @@ pub async fn main(
e,
)?;
server.serve_diagnostic(ctx.clone());
server.serve_ui_for(ctx.clone());
let shutdown = ctx.shutdown.subscribe().recv().await.unwrap();

View File

@@ -12,8 +12,9 @@ use tracing::instrument;
use crate::context::config::ServerConfig;
use crate::context::rpc::InitRpcContextPhases;
use crate::context::{DiagnosticContext, InitContext, RpcContext};
use crate::net::gateway::SelfContainedNetworkInterfaceListener;
use crate::net::web_server::{Acceptor, UpgradableListener, WebServer};
use crate::net::gateway::{BindTcp, SelfContainedNetworkInterfaceListener, UpgradableListener};
use crate::net::static_server::refresher;
use crate::net::web_server::{Acceptor, WebServer};
use crate::shutdown::Shutdown;
use crate::system::launch_metrics_task;
use crate::util::io::append_file;
@@ -38,7 +39,7 @@ async fn inner_main(
};
tokio::fs::write("/run/startos/initialized", "").await?;
server.serve_main(ctx.clone());
server.serve_ui_for(ctx.clone());
LOGGER.set_logfile(None);
handle.complete();
@@ -47,7 +48,7 @@ async fn inner_main(
let init_ctx = InitContext::init(config).await?;
let handle = init_ctx.progress.clone();
let rpc_ctx_phases = InitRpcContextPhases::new(&handle);
server.serve_init(init_ctx);
server.serve_ui_for(init_ctx);
let ctx = RpcContext::init(
&server.acceptor_setter(),
@@ -63,14 +64,14 @@ async fn inner_main(
)
.await?;
server.serve_main(ctx.clone());
server.serve_ui_for(ctx.clone());
handle.complete();
ctx
};
let (rpc_ctx, shutdown) = async {
crate::hostname::sync_hostname(&rpc_ctx.account.read().await.hostname).await?;
crate::hostname::sync_hostname(&rpc_ctx.account.peek(|a| a.hostname.clone())).await?;
let mut shutdown_recv = rpc_ctx.shutdown.subscribe();
@@ -147,9 +148,10 @@ pub fn main(args: impl IntoIterator<Item = OsString>) {
.build()
.expect("failed to initialize runtime");
let res = rt.block_on(async {
let mut server = WebServer::new(Acceptor::bind_upgradable(
SelfContainedNetworkInterfaceListener::bind(80),
));
let mut server = WebServer::new(
Acceptor::bind_upgradable(SelfContainedNetworkInterfaceListener::bind(BindTcp, 80)),
refresher(),
);
match inner_main(&mut server, &config).await {
Ok(a) => {
server.shutdown().await;
@@ -177,7 +179,7 @@ pub fn main(args: impl IntoIterator<Item = OsString>) {
e,
)?;
server.serve_diagnostic(ctx.clone());
server.serve_ui_for(ctx.clone());
let mut shutdown = ctx.shutdown.subscribe();

View File

@@ -1,29 +1,110 @@
use std::ffi::OsString;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use clap::Parser;
use futures::FutureExt;
use helpers::NonDetachingJoinHandle;
use rpc_toolkit::CliApp;
use tokio::signal::unix::signal;
use tracing::instrument;
use visit_rs::Visit;
use crate::context::CliContext;
use crate::context::config::ClientConfig;
use crate::net::web_server::{Acceptor, WebServer};
use crate::net::gateway::{Bind, BindTcp};
use crate::net::tls::TlsListener;
use crate::net::web_server::{Accept, Acceptor, MetadataVisitor, WebServer};
use crate::prelude::*;
use crate::tunnel::context::{TunnelConfig, TunnelContext};
use crate::tunnel::tunnel_router;
use crate::tunnel::web::TunnelCertHandler;
use crate::util::logger::LOGGER;
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum WebserverListener {
Http,
Https(SocketAddr),
}
impl<V: MetadataVisitor> Visit<V> for WebserverListener {
fn visit(&self, visitor: &mut V) -> <V as visit_rs::Visitor>::Result {
visitor.visit(self)
}
}
#[instrument(skip_all)]
async fn inner_main(config: &TunnelConfig) -> Result<(), Error> {
let server = async {
let ctx = TunnelContext::init(config).await?;
let mut server = WebServer::new(Acceptor::bind([ctx.listen]).await?);
server.serve_tunnel(ctx.clone());
let listen = ctx.listen;
let server = WebServer::new(
Acceptor::bind_map_dyn([(WebserverListener::Http, listen)]).await?,
tunnel_router(ctx.clone()),
);
let acceptor_setter = server.acceptor_setter();
let https_db = ctx.db.clone();
let https_thread: NonDetachingJoinHandle<()> = tokio::spawn(async move {
let mut sub = https_db.subscribe("/webserver".parse().unwrap()).await;
while {
while let Err(e) = async {
let webserver = https_db.peek().await.into_webserver();
if webserver.as_enabled().de()? {
let addr = webserver.as_listen().de()?.or_not_found("listen address")?;
acceptor_setter.send_if_modified(|a| {
let key = WebserverListener::Https(addr);
if !a.contains_key(&key) {
match (|| {
Ok::<_, Error>(TlsListener::new(
BindTcp.bind(addr)?,
TunnelCertHandler {
db: https_db.clone(),
crypto_provider: Arc::new(tokio_rustls::rustls::crypto::ring::default_provider()),
},
))
})() {
Ok(l) => {
a.retain(|k, _| *k == WebserverListener::Http);
a.insert(key, l.into_dyn());
true
}
Err(e) => {
tracing::error!("error adding ssl listener: {e}");
tracing::debug!("{e:?}");
false
}
}
} else {
false
}
});
} else {
acceptor_setter.send_if_modified(|a| {
let before = a.len();
a.retain(|k, _| *k == WebserverListener::Http);
a.len() != before
});
}
Ok::<_, Error>(())
}
.await
{
tracing::error!("error updating webserver bind: {e}");
tracing::debug!("{e:?}");
tokio::time::sleep(Duration::from_secs(5)).await;
}
sub.recv().await.is_some()
} {}
})
.into();
let mut shutdown_recv = ctx.shutdown.subscribe();
let sig_handler_ctx = ctx;
let sig_handler = tokio::spawn(async move {
let sig_handler: NonDetachingJoinHandle<()> = tokio::spawn(async move {
use tokio::signal::unix::SignalKind;
futures::future::select_all(
[
@@ -48,14 +129,16 @@ async fn inner_main(config: &TunnelConfig) -> Result<(), Error> {
.send(())
.map_err(|_| ())
.expect("send shutdown signal");
});
})
.into();
shutdown_recv
.recv()
.await
.with_kind(crate::ErrorKind::Unknown)?;
sig_handler.abort();
sig_handler.wait_for_abort().await.with_kind(ErrorKind::Unknown)?;
https_thread.wait_for_abort().await.with_kind(ErrorKind::Unknown)?;
Ok::<_, Error>(server)
}

View File

@@ -6,6 +6,7 @@ use std::sync::Arc;
use cookie::{Cookie, Expiration, SameSite};
use cookie_store::CookieStore;
use http::HeaderMap;
use imbl_value::InternedString;
use josekit::jwk::Jwk;
use once_cell::sync::OnceCell;
@@ -26,7 +27,7 @@ use crate::developer::{OS_DEVELOPER_KEY_PATH, default_developer_key_path};
use crate::middleware::auth::AuthContext;
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::tunnel::context::TunnelContext;
use crate::util::io::read_file_to_string;
#[derive(Debug)]
pub struct CliContextSeed {
@@ -159,7 +160,7 @@ impl CliContext {
continue;
}
let pair = <ed25519::KeypairBytes as ed25519::pkcs8::DecodePrivateKey>::from_pkcs8_pem(
&std::fs::read_to_string(&self.developer_key_path)?,
&std::fs::read_to_string(path)?,
)
.with_kind(crate::ErrorKind::Pem)?;
let secret = ed25519_dalek::SecretKey::try_from(&pair.secret_key[..]).map_err(|_| {
@@ -171,7 +172,7 @@ impl CliContext {
return Ok(secret.into())
}
Err(Error::new(
eyre!("Developer Key does not exist! Please run `start-cli init` before running this command."),
eyre!("Developer Key does not exist! Please run `start-cli init-key` before running this command."),
crate::ErrorKind::Uninitialized
))
})
@@ -233,23 +234,28 @@ impl CliContext {
&self,
method: &str,
params: Value,
) -> Result<Value, RpcError>
) -> Result<Value, Error>
where
Self: CallRemote<RemoteContext>,
{
<Self as CallRemote<RemoteContext, Empty>>::call_remote(&self, method, params, Empty {})
.await
.map_err(Error::from)
.with_ctx(|e| (e.kind, method))
}
pub async fn call_remote_with<RemoteContext, T>(
&self,
method: &str,
params: Value,
extra: T,
) -> Result<Value, RpcError>
) -> Result<Value, Error>
where
Self: CallRemote<RemoteContext, T>,
{
<Self as CallRemote<RemoteContext, T>>::call_remote(&self, method, params, extra).await
<Self as CallRemote<RemoteContext, T>>::call_remote(&self, method, params, extra)
.await
.map_err(Error::from)
.with_ctx(|e| (e.kind, method))
}
}
impl AsRef<Jwk> for CliContext {
@@ -279,9 +285,15 @@ impl Context for CliContext {
)
}
}
impl AsRef<Client> for CliContext {
fn as_ref(&self) -> &Client {
&self.client
}
}
impl CallRemote<RpcContext> for CliContext {
async fn call_remote(&self, method: &str, params: Value, _: Empty) -> Result<Value, RpcError> {
if let Ok(local) = std::fs::read_to_string(RpcContext::LOCAL_AUTH_COOKIE_PATH) {
if let Ok(local) = read_file_to_string(RpcContext::LOCAL_AUTH_COOKIE_PATH).await {
self.cookie_store
.lock()
.unwrap()
@@ -298,7 +310,8 @@ impl CallRemote<RpcContext> for CliContext {
crate::middleware::signature::call_remote(
self,
self.rpc_url.clone(),
self.rpc_url.host_str().or_not_found("rpc url hostname")?,
HeaderMap::new(),
self.rpc_url.host_str(),
method,
params,
)
@@ -307,24 +320,11 @@ impl CallRemote<RpcContext> for CliContext {
}
impl CallRemote<DiagnosticContext> for CliContext {
async fn call_remote(&self, method: &str, params: Value, _: Empty) -> Result<Value, RpcError> {
if let Ok(local) = std::fs::read_to_string(TunnelContext::LOCAL_AUTH_COOKIE_PATH) {
self.cookie_store
.lock()
.unwrap()
.insert_raw(
&Cookie::build(("local", local))
.domain("localhost")
.expires(Expiration::Session)
.same_site(SameSite::Strict)
.build(),
&"http://localhost".parse()?,
)
.with_kind(crate::ErrorKind::Network)?;
}
crate::middleware::signature::call_remote(
self,
self.rpc_url.clone(),
self.rpc_url.host_str().or_not_found("rpc url hostname")?,
HeaderMap::new(),
self.rpc_url.host_str(),
method,
params,
)
@@ -336,7 +336,8 @@ impl CallRemote<InitContext> for CliContext {
crate::middleware::signature::call_remote(
self,
self.rpc_url.clone(),
self.rpc_url.host_str().or_not_found("rpc url hostname")?,
HeaderMap::new(),
self.rpc_url.host_str(),
method,
params,
)
@@ -348,7 +349,8 @@ impl CallRemote<SetupContext> for CliContext {
crate::middleware::signature::call_remote(
self,
self.rpc_url.clone(),
self.rpc_url.host_str().or_not_found("rpc url hostname")?,
HeaderMap::new(),
self.rpc_url.host_str(),
method,
params,
)
@@ -360,22 +362,11 @@ impl CallRemote<InstallContext> for CliContext {
crate::middleware::signature::call_remote(
self,
self.rpc_url.clone(),
self.rpc_url.host_str().or_not_found("rpc url hostname")?,
HeaderMap::new(),
self.rpc_url.host_str(),
method,
params,
)
.await
}
}
#[test]
fn test() {
let ctx = CliContext::init(ClientConfig::default()).unwrap();
ctx.runtime().unwrap().block_on(async {
reqwest::Client::new()
.get("http://example.com")
.send()
.await
.unwrap();
});
}

View File

@@ -6,10 +6,10 @@ use tokio::sync::broadcast::Sender;
use tokio::sync::watch;
use tracing::instrument;
use crate::Error;
use crate::context::config::ServerConfig;
use crate::progress::FullProgressTracker;
use crate::rpc_continuations::RpcContinuations;
use crate::Error;
pub struct InitContextSeed {
pub config: ServerConfig,

View File

@@ -1,11 +1,10 @@
use std::collections::{BTreeMap, BTreeSet};
use std::ffi::OsStr;
use std::future::Future;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use chrono::{TimeDelta, Utc};
@@ -18,36 +17,37 @@ use models::{ActionId, PackageId};
use reqwest::{Client, Proxy};
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{CallRemote, Context, Empty};
use tokio::sync::{broadcast, oneshot, watch, RwLock};
use tokio::sync::{RwLock, broadcast, oneshot, watch};
use tokio::time::Instant;
use tracing::instrument;
use super::setup::CURRENT_SECRET;
use crate::DATA_DIR;
use crate::account::AccountInfo;
use crate::auth::Sessions;
use crate::context::config::ServerConfig;
use crate::db::model::package::TaskSeverity;
use crate::db::model::Database;
use crate::db::model::package::TaskSeverity;
use crate::disk::OsPartitionInfo;
use crate::init::{check_time_is_synchronized, InitResult};
use crate::init::{InitResult, check_time_is_synchronized};
use crate::install::PKG_ARCHIVE_DIR;
use crate::lxc::LxcManager;
use crate::net::gateway::UpgradableListener;
use crate::net::net_controller::{NetController, NetService};
use crate::net::socks::DEFAULT_SOCKS_LISTEN;
use crate::net::utils::{find_eth_iface, find_wifi_iface};
use crate::net::web_server::{UpgradableListener, WebServerAcceptorSetter};
use crate::net::web_server::WebServerAcceptorSetter;
use crate::net::wifi::WpaCli;
use crate::prelude::*;
use crate::progress::{FullProgressTracker, PhaseProgressTrackerHandle};
use crate::rpc_continuations::{Guid, OpenAuthedContinuations, RpcContinuations};
use crate::service::ServiceMap;
use crate::service::action::update_tasks;
use crate::service::effects::callbacks::ServiceCallbacks;
use crate::service::ServiceMap;
use crate::shutdown::Shutdown;
use crate::util::io::delete_file;
use crate::util::lshw::LshwDevice;
use crate::util::sync::{SyncMutex, Watch};
use crate::{DATA_DIR, HOST_IP};
use crate::util::sync::{SyncMutex, SyncRwLock, Watch};
pub struct RpcContextSeed {
is_closed: AtomicBool,
@@ -58,7 +58,7 @@ pub struct RpcContextSeed {
pub ephemeral_sessions: SyncMutex<Sessions>,
pub db: TypedPatchDb<Database>,
pub sync_db: watch::Sender<u64>,
pub account: RwLock<AccountInfo>,
pub account: SyncRwLock<AccountInfo>,
pub net_controller: Arc<NetController>,
pub os_net_service: NetService,
pub s9pk_arch: Option<&'static str>,
@@ -225,7 +225,7 @@ impl RpcContext {
ephemeral_sessions: SyncMutex::new(Sessions::new()),
sync_db: watch::Sender::new(db.sequence().await),
db,
account: RwLock::new(account),
account: SyncRwLock::new(account),
callbacks: net_controller.callbacks.clone(),
net_controller,
os_net_service,
@@ -483,6 +483,11 @@ impl RpcContext {
<Self as CallRemote<RemoteContext, T>>::call_remote(&self, method, params, extra).await
}
}
impl AsRef<Client> for RpcContext {
fn as_ref(&self) -> &Client {
&self.client
}
}
impl AsRef<Jwk> for RpcContext {
fn as_ref(&self) -> &Jwk {
&CURRENT_SECRET

View File

@@ -10,24 +10,25 @@ use josekit::jwk::Jwk;
use patch_db::PatchDb;
use rpc_toolkit::Context;
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast::Sender;
use tokio::sync::OnceCell;
use tokio::sync::broadcast::Sender;
use tracing::instrument;
use ts_rs::TS;
use crate::MAIN_DATA;
use crate::account::AccountInfo;
use crate::context::config::ServerConfig;
use crate::context::RpcContext;
use crate::context::config::ServerConfig;
use crate::disk::OsPartitionInfo;
use crate::hostname::Hostname;
use crate::net::web_server::{UpgradableListener, WebServer, WebServerAcceptorSetter};
use crate::net::gateway::UpgradableListener;
use crate::net::web_server::{WebServer, WebServerAcceptorSetter};
use crate::prelude::*;
use crate::progress::FullProgressTracker;
use crate::rpc_continuations::{Guid, RpcContinuation, RpcContinuations};
use crate::setup::SetupProgress;
use crate::shutdown::Shutdown;
use crate::util::net::WebSocketExt;
use crate::MAIN_DATA;
lazy_static::lazy_static! {
pub static ref CURRENT_SECRET: Jwk = Jwk::generate_ec_key(josekit::jwk::alg::ec::EcCurve::P256).unwrap_or_else(|e| {

View File

@@ -1,6 +1,7 @@
pub mod model;
pub mod prelude;
use std::panic::UnwindSafe;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
@@ -29,6 +30,26 @@ lazy_static::lazy_static! {
static ref PUBLIC: JsonPointer = "/public".parse().unwrap();
}
/// Typed accessor from a database model to a sub-model embedded inside it.
///
/// Implementing `DbAccess<T>` for a type says: given a `Model<Self>`, a
/// `Model<T>` can always be borrowed out of it.
pub trait DbAccess<T>: Sized {
    /// Borrow the embedded `Model<T>` out of `db`.
    fn access<'a>(db: &'a Model<Self>) -> &'a Model<T>;
}

/// Mutable counterpart of [`DbAccess`].
pub trait DbAccessMut<T>: DbAccess<T> {
    /// Mutably borrow the embedded `Model<T>` out of `db`.
    fn access_mut<'a>(db: &'a mut Model<Self>) -> &'a mut Model<T>;
}

/// Keyed accessor from a database model to a sub-model that is looked up by
/// a key and may be absent (e.g. an entry of a keyed collection).
pub trait DbAccessByKey<T>: Sized {
    /// Key type for the lookup; generic over a lifetime so borrowed keys
    /// (e.g. `&'a SomeId`) can be used without cloning.
    type Key<'a>;
    /// Borrow the `Model<T>` stored under `key`, if present.
    fn access_by_key<'a>(db: &'a Model<Self>, key: Self::Key<'_>) -> Option<&'a Model<T>>;
}

/// Mutable counterpart of [`DbAccessByKey`].
pub trait DbAccessMutByKey<T>: DbAccessByKey<T> {
    /// Mutably borrow the `Model<T>` stored under `key`, if present.
    fn access_mut_by_key<'a>(
        db: &'a mut Model<Self>,
        key: Self::Key<'_>,
    ) -> Option<&'a mut Model<T>>;
}
pub fn db<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
.subcommand(
@@ -127,7 +148,7 @@ pub struct SubscribeParams {
#[ts(type = "string | null")]
pointer: Option<JsonPointer>,
#[ts(skip)]
#[serde(rename = "__auth_session")]
#[serde(rename = "__Auth_session")]
session: Option<InternedString>,
}

View File

@@ -1,5 +1,6 @@
use std::collections::{BTreeMap, BTreeSet, VecDeque};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use chrono::{DateTime, Utc};
use exver::{Version, VersionRange};
@@ -8,7 +9,6 @@ use imbl_value::InternedString;
use ipnet::IpNet;
use isocountry::CountryCode;
use itertools::Itertools;
use lazy_static::lazy_static;
use models::{GatewayId, PackageId};
use openssl::hash::MessageDigest;
use patch_db::{HasModel, Value};
@@ -16,11 +16,12 @@ use serde::{Deserialize, Serialize};
use ts_rs::TS;
use crate::account::AccountInfo;
use crate::db::DbAccessByKey;
use crate::db::model::Database;
use crate::db::model::package::AllPackageData;
use crate::net::acme::AcmeProvider;
use crate::net::forward::START9_BRIDGE_IFACE;
use crate::net::host::binding::{AddSslOptions, BindInfo, BindOptions, NetInfo};
use crate::net::host::Host;
use crate::net::host::binding::{AddSslOptions, BindInfo, BindOptions, NetInfo};
use crate::net::utils::ipv6_is_local;
use crate::net::vhost::AlpnInfo;
use crate::prelude::*;
@@ -30,7 +31,7 @@ use crate::util::cpupower::Governor;
use crate::util::lshw::LshwDevice;
use crate::util::serde::MaybeUtf8String;
use crate::version::{Current, VersionT};
use crate::{ARCH, HOST_IP, PLATFORM};
use crate::{ARCH, PLATFORM};
#[derive(Debug, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
@@ -122,11 +123,20 @@ impl Public {
kiosk,
},
package_data: AllPackageData::default(),
ui: serde_json::from_str(include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/../../web/patchdb-ui-seed.json"
)))
.with_kind(ErrorKind::Deserialization)?,
ui: {
#[cfg(feature = "startd")]
{
serde_json::from_str(include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/../../web/patchdb-ui-seed.json"
)))
.with_kind(ErrorKind::Deserialization)?
}
#[cfg(not(feature = "startd"))]
{
Value::Null
}
},
})
}
}
@@ -216,64 +226,9 @@ pub struct NetworkInterfaceInfo {
pub name: Option<InternedString>,
pub public: Option<bool>,
pub secure: Option<bool>,
pub ip_info: Option<IpInfo>,
pub ip_info: Option<Arc<IpInfo>>,
}
impl NetworkInterfaceInfo {
pub fn loopback() -> (&'static GatewayId, &'static Self) {
lazy_static! {
static ref LO: GatewayId = GatewayId::from(InternedString::intern("lo"));
static ref LOOPBACK: NetworkInterfaceInfo = NetworkInterfaceInfo {
name: Some(InternedString::from_static("Loopback")),
public: Some(false),
secure: Some(true),
ip_info: Some(IpInfo {
name: "lo".into(),
scope_id: 1,
device_type: None,
subnets: [
IpNet::new(Ipv4Addr::LOCALHOST.into(), 8).unwrap(),
IpNet::new(Ipv6Addr::LOCALHOST.into(), 128).unwrap(),
]
.into_iter()
.collect(),
lan_ip: [
IpAddr::from(Ipv4Addr::LOCALHOST),
IpAddr::from(Ipv6Addr::LOCALHOST)
]
.into_iter()
.collect(),
wan_ip: None,
ntp_servers: Default::default(),
dns_servers: Default::default(),
}),
};
}
(&*LO, &*LOOPBACK)
}
pub fn lxc_bridge() -> (&'static GatewayId, &'static Self) {
lazy_static! {
static ref LXCBR0: GatewayId =
GatewayId::from(InternedString::intern(START9_BRIDGE_IFACE));
static ref LXC_BRIDGE: NetworkInterfaceInfo = NetworkInterfaceInfo {
name: Some(InternedString::from_static("LXC Bridge Interface")),
public: Some(false),
secure: Some(true),
ip_info: Some(IpInfo {
name: START9_BRIDGE_IFACE.into(),
scope_id: 0,
device_type: None,
subnets: [IpNet::new(HOST_IP.into(), 24).unwrap()]
.into_iter()
.collect(),
lan_ip: [IpAddr::from(HOST_IP)].into_iter().collect(),
wan_ip: None,
ntp_servers: Default::default(),
dns_servers: Default::default(),
}),
};
}
(&*LXCBR0, &*LXC_BRIDGE)
}
pub fn public(&self) -> bool {
self.public.unwrap_or_else(|| {
!self.ip_info.as_ref().map_or(true, |ip_info| {
@@ -308,7 +263,7 @@ impl NetworkInterfaceInfo {
self.secure.unwrap_or_else(|| {
self.ip_info.as_ref().map_or(false, |ip_info| {
ip_info.device_type == Some(NetworkInterfaceType::Wireguard)
})
}) && !self.public()
})
}
}
@@ -333,13 +288,15 @@ pub struct IpInfo {
pub dns_servers: OrdSet<IpAddr>,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize, TS)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, TS)]
#[ts(export)]
#[serde(rename_all = "kebab-case")]
pub enum NetworkInterfaceType {
Ethernet,
Wireless,
Bridge,
Wireguard,
Loopback,
}
#[derive(Debug, Deserialize, Serialize, HasModel, TS)]
@@ -349,6 +306,19 @@ pub enum NetworkInterfaceType {
pub struct AcmeSettings {
pub contact: Vec<String>,
}
/// Look up the [`AcmeSettings`] stored for a given ACME provider.
///
/// Navigates `public → server_info → network → acme` in the database tree
/// and indexes by provider; yields `None` when no settings are configured
/// for that provider.
impl DbAccessByKey<AcmeSettings> for Database {
    type Key<'a> = &'a AcmeProvider;
    fn access_by_key<'a>(
        db: &'a Model<Self>,
        key: Self::Key<'_>,
    ) -> Option<&'a Model<AcmeSettings>> {
        let acme = db.as_public().as_server_info().as_network().as_acme();
        acme.as_idx(key)
    }
}
#[derive(Debug, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]

View File

@@ -1,6 +1,7 @@
use std::collections::{BTreeMap, BTreeSet};
use std::marker::PhantomData;
use std::str::FromStr;
use std::sync::Arc;
use chrono::{DateTime, Utc};
use imbl::OrdMap;
@@ -167,6 +168,21 @@ impl<T> Model<Option<T>> {
}
}
/// Accessors that reinterpret a `Model<Arc<T>>` as a `Model<T>`.
///
/// NOTE(review): this relies on `Arc<T>` and `T` sharing a serialized
/// representation (serde's `Arc` impls forward to the inner value), so the
/// `transmute` calls are pure type-level changes over the same underlying
/// JSON value — confirm the `Arc` serde impls are in use for these types.
impl<T> Model<Arc<T>> {
    /// Consume the model, reinterpreting it as a `Model<T>`.
    pub fn deref(self) -> Model<T> {
        use patch_db::ModelExt;
        self.transmute(|a| a)
    }
    /// Borrow the model as a `&Model<T>`.
    pub fn as_deref(&self) -> &Model<T> {
        use patch_db::ModelExt;
        self.transmute_ref(|a| a)
    }
    /// Mutably borrow the model as a `&mut Model<T>`.
    pub fn as_deref_mut(&mut self) -> &mut Model<T> {
        use patch_db::ModelExt;
        self.transmute_mut(|a| a)
    }
}
pub trait Map: DeserializeOwned + Serialize {
type Key;
type Value;
@@ -193,10 +209,10 @@ where
A: serde::Serialize + serde::de::DeserializeOwned + Ord,
B: serde::Serialize + serde::de::DeserializeOwned,
{
type Key = A;
type Key = JsonKey<A>;
type Value = B;
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
serde_json::to_string(key).with_kind(ErrorKind::Serialization)
serde_json::to_string(&key.0).with_kind(ErrorKind::Serialization)
}
}
@@ -216,13 +232,18 @@ impl<T: Map> Model<T>
where
T::Value: Serialize,
{
pub fn insert(&mut self, key: &T::Key, value: &T::Value) -> Result<(), Error> {
pub fn insert_model(
&mut self,
key: &T::Key,
value: Model<T::Value>,
) -> Result<Option<Model<T::Value>>, Error> {
use patch_db::ModelExt;
use serde::ser::Error;
let v = patch_db::value::to_value(value)?;
let v = value.into_value();
match &mut self.value {
Value::Object(o) => {
o.insert(T::key_string(key)?, v);
Ok(())
let prev = o.insert(T::key_string(key)?, v);
Ok(prev.map(|v| Model::from_value(v)))
}
v => Err(patch_db::value::Error {
source: patch_db::value::ErrorSource::custom(format!("expected object found {v}")),
@@ -231,6 +252,13 @@ where
.into()),
}
}
pub fn insert(
&mut self,
key: &T::Key,
value: &T::Value,
) -> Result<Option<Model<T::Value>>, Error> {
self.insert_model(key, Model::new(value)?)
}
pub fn upsert<F>(&mut self, key: &T::Key, value: F) -> Result<&mut Model<T::Value>, Error>
where
F: FnOnce() -> Result<T::Value, Error>,
@@ -257,22 +285,6 @@ where
.into()),
}
}
pub fn insert_model(&mut self, key: &T::Key, value: Model<T::Value>) -> Result<(), Error> {
use patch_db::ModelExt;
use serde::ser::Error;
let v = value.into_value();
match &mut self.value {
Value::Object(o) => {
o.insert(T::key_string(key)?, v);
Ok(())
}
v => Err(patch_db::value::Error {
source: patch_db::value::ErrorSource::custom(format!("expected object found {v}")),
kind: patch_db::value::ErrorKind::Serialization,
}
.into()),
}
}
}
impl<T: Map> Model<T>
@@ -437,6 +449,12 @@ impl<T> std::ops::DerefMut for JsonKey<T> {
&mut self.0
}
}
/// Parse a [`JsonKey<T>`] from its JSON string form, mapping any serde
/// failure to a deserialization error.
impl<T: DeserializeOwned> FromStr for JsonKey<T> {
    type Err = Error;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let parsed: Result<Self, _> = serde_json::from_str(s);
        parsed.with_kind(ErrorKind::Deserialization)
    }
}
impl<T: Serialize> Serialize for JsonKey<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
@@ -449,7 +467,7 @@ impl<T: Serialize> Serialize for JsonKey<T> {
}
}
// { "foo": "bar" } -> "{ \"foo\": \"bar\" }"
impl<'de, T: Serialize + DeserializeOwned> Deserialize<'de> for JsonKey<T> {
impl<'de, T: DeserializeOwned> Deserialize<'de> for JsonKey<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,

View File

@@ -6,9 +6,9 @@ use models::PackageId;
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use crate::Error;
use crate::prelude::*;
use crate::util::PathOrUrl;
use crate::Error;
#[derive(Clone, Debug, Default, Deserialize, Serialize, HasModel, TS)]
#[model = "Model<Self>"]

View File

@@ -80,23 +80,6 @@ impl<Fs: FileSystem> FileSystem for IdMapped<Fs> {
}
Ok(())
}
async fn mount<P: AsRef<Path> + Send>(
&self,
mountpoint: P,
mount_type: MountType,
) -> Result<(), Error> {
self.pre_mount(mountpoint.as_ref()).await?;
Command::new("mount.next")
.args(
default_mount_command(self, mountpoint, mount_type)
.await?
.get_args(),
)
.invoke(ErrorKind::Filesystem)
.await?;
Ok(())
}
async fn source_hash(
&self,
) -> Result<GenericArray<u8, <Sha256 as OutputSizeUser>::OutputSize>, Error> {

View File

@@ -10,8 +10,8 @@ use tracing::instrument;
use super::filesystem::{FileSystem, MountType, ReadOnly, ReadWrite};
use super::util::unmount;
use crate::util::{Invoke, Never};
use crate::Error;
use crate::util::{Invoke, Never};
pub const TMP_MOUNTPOINT: &'static str = "/media/startos/tmp";

View File

@@ -8,7 +8,7 @@ use const_format::formatcp;
use futures::{StreamExt, TryStreamExt};
use itertools::Itertools;
use models::ResultExt;
use rpc_toolkit::{from_fn_async, Context, Empty, HandlerArgs, HandlerExt, ParentHandler};
use rpc_toolkit::{Context, Empty, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use tracing::instrument;
@@ -17,15 +17,16 @@ use ts_rs::TS;
use crate::account::AccountInfo;
use crate::context::config::ServerConfig;
use crate::context::{CliContext, InitContext, RpcContext};
use crate::db::model::public::ServerStatus;
use crate::db::model::Database;
use crate::db::model::public::ServerStatus;
use crate::developer::OS_DEVELOPER_KEY_PATH;
use crate::hostname::Hostname;
use crate::middleware::auth::AuthContext;
use crate::net::gateway::UpgradableListener;
use crate::net::net_controller::{NetController, NetService};
use crate::net::socks::DEFAULT_SOCKS_LISTEN;
use crate::net::utils::find_wifi_iface;
use crate::net::web_server::{UpgradableListener, WebServerAcceptorSetter};
use crate::net::web_server::WebServerAcceptorSetter;
use crate::prelude::*;
use crate::progress::{
FullProgress, FullProgressTracker, PhaseProgressTrackerHandle, PhasedProgressBar, ProgressUnits,
@@ -34,10 +35,10 @@ use crate::rpc_continuations::{Guid, RpcContinuation};
use crate::s9pk::v2::pack::{CONTAINER_DATADIR, CONTAINER_TOOL};
use crate::ssh::SSH_DIR;
use crate::system::{get_mem_info, sync_kiosk};
use crate::util::io::{open_file, IOHook};
use crate::util::io::{IOHook, open_file};
use crate::util::lshw::lshw;
use crate::util::net::WebSocketExt;
use crate::util::{cpupower, Invoke};
use crate::util::{Invoke, cpupower};
use crate::{Error, MAIN_DATA, PACKAGE_DATA};
pub const SYSTEM_REBUILD_PATH: &str = "/media/startos/config/system-rebuild";

View File

@@ -7,7 +7,7 @@ use clap::builder::ValueParserFactory;
use clap::{CommandFactory, FromArgMatches, Parser, value_parser};
use color_eyre::eyre::eyre;
use exver::VersionRange;
use futures::{AsyncWriteExt, StreamExt};
use futures::StreamExt;
use imbl_value::{InternedString, json};
use itertools::Itertools;
use models::{FromStrParser, VersionString};
@@ -15,7 +15,6 @@ use reqwest::Url;
use reqwest::header::{CONTENT_LENGTH, HeaderMap};
use rpc_toolkit::HandlerArgs;
use rpc_toolkit::yajrc::RpcError;
use rustyline_async::ReadlineEvent;
use serde::{Deserialize, Serialize};
use tokio::sync::oneshot;
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
@@ -34,6 +33,7 @@ use crate::upload::upload;
use crate::util::Never;
use crate::util::io::open_file;
use crate::util::net::WebSocketExt;
use crate::util::tui::choose;
pub const PKG_ARCHIVE_DIR: &str = "package-data/archive";
pub const PKG_PUBLIC_DIR: &str = "package-data/public";
@@ -175,7 +175,7 @@ pub async fn install(
#[serde(rename_all = "camelCase")]
pub struct SideloadParams {
#[ts(skip)]
#[serde(rename = "__auth_session")]
#[serde(rename = "__Auth_session")]
session: Option<InternedString>,
}
@@ -483,47 +483,19 @@ pub async fn cli_install(
let version = if packages.best.len() == 1 {
packages.best.pop_first().map(|(k, _)| k).unwrap()
} else {
println!(
"Multiple flavors of {id} found. Please select one of the following versions to install:"
);
let version;
loop {
let (mut read, mut output) = rustyline_async::Readline::new("> ".into())
.with_kind(ErrorKind::Filesystem)?;
for (idx, version) in packages.best.keys().enumerate() {
output
.write_all(format!(" {}) {}\n", idx + 1, version).as_bytes())
.await?;
read.add_history_entry(version.to_string());
}
if let ReadlineEvent::Line(line) = read.readline().await? {
let trimmed = line.trim();
match trimmed.parse() {
Ok(v) => {
if let Some((k, _)) = packages.best.remove_entry(&v) {
version = k;
break;
}
}
Err(_) => match trimmed.parse::<usize>() {
Ok(i) if (1..=packages.best.len()).contains(&i) => {
version = packages.best.keys().nth(i - 1).unwrap().clone();
break;
}
_ => (),
},
}
eprintln!("invalid selection: {trimmed}");
println!("Please select one of the following versions to install:");
} else {
return Err(Error::new(
eyre!("Could not determine precise version to install"),
ErrorKind::InvalidRequest,
)
.into());
}
}
version
let versions = packages.best.keys().collect::<Vec<_>>();
let version = choose(
&format!(
concat!(
"Multiple flavors of {id} found. ",
"Please select one of the following versions to install:"
),
id = id
),
&versions,
)
.await?;
(*version).clone()
};
ctx.call_remote::<RpcContext>(
&method.join("."),

View File

@@ -80,17 +80,16 @@ use imbl_value::Value;
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{
CallRemoteHandler, Context, Empty, HandlerExt, ParentHandler, from_fn, from_fn_async,
from_fn_blocking,
from_fn_async_local, from_fn_blocking,
};
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use crate::context::{
CliContext, DiagnosticContext, InitContext, InstallContext, RpcContext, SetupContext,
};
use crate::context::{CliContext, DiagnosticContext, InitContext, RpcContext};
use crate::disk::fsck::RequiresReboot;
use crate::registry::context::{RegistryContext, RegistryUrlParams};
use crate::system::kiosk;
use crate::tunnel::context::TunnelUrlParams;
use crate::util::serde::{HandlerExtSerde, WithIoFormat, display_serializable};
#[derive(Deserialize, Serialize, Parser, TS)]
@@ -139,6 +138,20 @@ pub fn main_api<C: Context>() -> ParentHandler<C> {
.with_about("Display the API that is currently serving")
.with_call_remote::<CliContext>(),
)
.subcommand(
"state",
from_fn(|_: InitContext| Ok::<_, Error>(ApiState::Initializing))
.with_metadata("authenticated", Value::Bool(false))
.with_about("Display the API that is currently serving")
.with_call_remote::<CliContext>(),
)
.subcommand(
"state",
from_fn(|_: DiagnosticContext| Ok::<_, Error>(ApiState::Error))
.with_metadata("authenticated", Value::Bool(false))
.with_about("Display the API that is currently serving")
.with_call_remote::<CliContext>(),
)
.subcommand(
"server",
server::<C>()
@@ -191,6 +204,19 @@ pub fn main_api<C: Context>() -> ParentHandler<C> {
)
.no_cli(),
)
.subcommand(
"registry",
registry::registry_api::<CliContext>().with_about("Commands related to the registry"),
)
.subcommand(
"tunnel",
CallRemoteHandler::<RpcContext, _, _, TunnelUrlParams>::new(tunnel::api::tunnel_api())
.no_cli(),
)
.subcommand(
"tunnel",
tunnel::api::tunnel_api::<CliContext>().with_about("Commands related to StartTunnel"),
)
.subcommand(
"s9pk",
s9pk::rpc::s9pk().with_about("Commands for interacting with s9pk files"),
@@ -198,6 +224,29 @@ pub fn main_api<C: Context>() -> ParentHandler<C> {
.subcommand(
"util",
util::rpc::util::<C>().with_about("Command for calculating the blake3 hash of a file"),
)
.subcommand(
"init-key",
from_fn_async(developer::init)
.no_display()
.with_about("Create developer key if it doesn't exist"),
)
.subcommand(
"pubkey",
from_fn_blocking(developer::pubkey)
.with_about("Get public key for developer private key"),
)
.subcommand(
"diagnostic",
diagnostic::diagnostic::<C>()
.with_about("Commands to display logs, restart the server, etc"),
)
.subcommand("init", init::init_api::<C>())
.subcommand("setup", setup::setup::<C>())
.subcommand(
"install",
os_install::install::<C>()
.with_about("Commands to list disk info, install StartOS, and reboot"),
);
if &*PLATFORM != "raspberrypi" {
api = api.subcommand("kiosk", kiosk::<C>());
@@ -343,7 +392,7 @@ pub fn package<C: Context>() -> ParentHandler<C> {
)
.subcommand(
"install",
from_fn_async(install::cli_install)
from_fn_async_local(install::cli_install)
.no_display()
.with_about("Install a package from a marketplace or via sideloading"),
)
@@ -464,13 +513,6 @@ pub fn package<C: Context>() -> ParentHandler<C> {
backup::package_backup::<C>()
.with_about("Commands for restoring package(s) from backup"),
)
.subcommand("connect", from_fn_async(service::connect_rpc).no_cli())
.subcommand(
"connect",
from_fn_async(service::connect_rpc_cli)
.no_display()
.with_about("Connect to a LXC container"),
)
.subcommand(
"attach",
from_fn_async(service::attach)
@@ -484,127 +526,3 @@ pub fn package<C: Context>() -> ParentHandler<C> {
net::host::host_api::<C>().with_about("Manage network hosts for a package"),
)
}
pub fn diagnostic_api() -> ParentHandler<DiagnosticContext> {
ParentHandler::new()
.subcommand(
"git-info",
from_fn(|_: DiagnosticContext| version::git_info())
.with_metadata("authenticated", Value::Bool(false))
.with_about("Display the githash of StartOS CLI"),
)
.subcommand(
"echo",
from_fn(echo::<DiagnosticContext>)
.with_about("Echo a message")
.with_call_remote::<CliContext>(),
)
.subcommand(
"state",
from_fn(|_: DiagnosticContext| Ok::<_, Error>(ApiState::Error))
.with_metadata("authenticated", Value::Bool(false))
.with_about("Display the API that is currently serving")
.with_call_remote::<CliContext>(),
)
.subcommand(
"diagnostic",
diagnostic::diagnostic::<DiagnosticContext>()
.with_about("Diagnostic commands i.e. logs, restart, rebuild"),
)
}
pub fn init_api() -> ParentHandler<InitContext> {
ParentHandler::new()
.subcommand(
"git-info",
from_fn(|_: InitContext| version::git_info())
.with_metadata("authenticated", Value::Bool(false))
.with_about("Display the githash of StartOS CLI"),
)
.subcommand(
"echo",
from_fn(echo::<InitContext>)
.with_about("Echo a message")
.with_call_remote::<CliContext>(),
)
.subcommand(
"state",
from_fn(|_: InitContext| Ok::<_, Error>(ApiState::Initializing))
.with_metadata("authenticated", Value::Bool(false))
.with_about("Display the API that is currently serving")
.with_call_remote::<CliContext>(),
)
.subcommand(
"init",
init::init_api::<InitContext>()
.with_about("Commands to get logs or initialization progress"),
)
}
pub fn setup_api() -> ParentHandler<SetupContext> {
ParentHandler::new()
.subcommand(
"git-info",
from_fn(|_: SetupContext| version::git_info())
.with_metadata("authenticated", Value::Bool(false))
.with_about("Display the githash of StartOS CLI"),
)
.subcommand(
"echo",
from_fn(echo::<SetupContext>)
.with_about("Echo a message")
.with_call_remote::<CliContext>(),
)
.subcommand("setup", setup::setup::<SetupContext>())
}
pub fn install_api() -> ParentHandler<InstallContext> {
ParentHandler::new()
.subcommand(
"git-info",
from_fn(|_: InstallContext| version::git_info())
.with_metadata("authenticated", Value::Bool(false))
.with_about("Display the githash of StartOS CLI"),
)
.subcommand(
"echo",
from_fn(echo::<InstallContext>)
.with_about("Echo a message")
.with_call_remote::<CliContext>(),
)
.subcommand(
"install",
os_install::install::<InstallContext>()
.with_about("Commands to list disk info, install StartOS, and reboot"),
)
}
pub fn expanded_api() -> ParentHandler<CliContext> {
main_api()
.subcommand(
"init",
from_fn_async(developer::init)
.no_display()
.with_about("Create developer key if it doesn't exist"),
)
.subcommand(
"pubkey",
from_fn_blocking(developer::pubkey)
.with_about("Get public key for developer private key"),
)
.subcommand(
"diagnostic",
diagnostic::diagnostic::<CliContext>()
.with_about("Commands to display logs, restart the server, etc"),
)
.subcommand("setup", setup::setup::<CliContext>())
.subcommand(
"install",
os_install::install::<CliContext>()
.with_about("Commands to list disk info, install StartOS, and reboot"),
)
.subcommand(
"registry",
registry::registry_api::<CliContext>().with_about("Commands related to the registry"),
)
}

View File

@@ -15,7 +15,7 @@ use itertools::Itertools;
use models::{FromStrParser, PackageId};
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{
from_fn_async, CallRemote, Context, Empty, HandlerArgs, HandlerExt, HandlerFor, ParentHandler,
CallRemote, Context, Empty, HandlerArgs, HandlerExt, HandlerFor, ParentHandler, from_fn_async,
};
use serde::de::{self, DeserializeOwned};
use serde::{Deserialize, Serialize};
@@ -30,9 +30,9 @@ use crate::error::ResultExt;
use crate::lxc::ContainerId;
use crate::prelude::*;
use crate::rpc_continuations::{Guid, RpcContinuation, RpcContinuations};
use crate::util::Invoke;
use crate::util::net::WebSocketExt;
use crate::util::serde::Reversible;
use crate::util::Invoke;
#[pin_project::pin_project]
pub struct LogStream {

View File

@@ -10,7 +10,6 @@ use imbl_value::{InOMap, InternedString};
use models::{FromStrParser, InvalidId, PackageId};
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{GenericRpcMethod, RpcRequest, RpcResponse};
use rustyline_async::{ReadlineEvent, SharedWriter};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;
@@ -470,115 +469,6 @@ pub async fn connect(ctx: &RpcContext, container: &LxcContainer) -> Result<Guid,
Ok(guid)
}
pub async fn connect_cli(ctx: &CliContext, guid: Guid) -> Result<(), Error> {
use futures::SinkExt;
use tokio_tungstenite::tungstenite::Message;
let mut ws = ctx.ws_continuation(guid).await?;
let (mut input, mut output) =
rustyline_async::Readline::new("> ".into()).with_kind(ErrorKind::Filesystem)?;
async fn handle_message(
msg: Option<Result<Message, tokio_tungstenite::tungstenite::Error>>,
output: &mut SharedWriter,
) -> Result<bool, Error> {
match msg {
None => return Ok(true),
Some(Ok(Message::Text(txt))) => match serde_json::from_str::<RpcResponse>(&txt) {
Ok(RpcResponse { result: Ok(a), .. }) => {
output
.write_all(
(serde_json::to_string(&a).with_kind(ErrorKind::Serialization)? + "\n")
.as_bytes(),
)
.await?;
}
Ok(RpcResponse { result: Err(e), .. }) => {
let e: Error = e.into();
tracing::error!("{e}");
tracing::debug!("{e:?}");
}
Err(e) => {
tracing::error!("Error Parsing RPC response: {e}");
tracing::debug!("{e:?}");
}
},
Some(Ok(_)) => (),
Some(Err(e)) => {
return Err(Error::new(e, ErrorKind::Network));
}
};
Ok(false)
}
loop {
tokio::select! {
line = input.readline() => {
let line = line.with_kind(ErrorKind::Filesystem)?;
if let ReadlineEvent::Line(line) = line {
input.add_history_entry(line.clone());
if serde_json::from_str::<RpcRequest>(&line).is_ok() {
ws.send(Message::Text(line.into()))
.await
.with_kind(ErrorKind::Network)?;
} else {
match shell_words::split(&line) {
Ok(command) => {
if let Some((method, rest)) = command.split_first() {
let mut params = InOMap::new();
for arg in rest {
if let Some((name, value)) = arg.split_once('=') {
params.insert(InternedString::intern(name), if value.is_empty() {
Value::Null
} else if let Ok(v) = serde_json::from_str(value) {
v
} else {
Value::String(Arc::new(value.into()))
});
} else {
tracing::error!("argument without a value: {arg}");
tracing::debug!("help: set the value of {arg} with `{arg}=...`");
continue;
}
}
ws.send(Message::Text(match serde_json::to_string(&RpcRequest {
id: None,
method: GenericRpcMethod::new(method.into()),
params: Value::Object(params),
}) {
Ok(a) => a.into(),
Err(e) => {
tracing::error!("Error Serializing Request: {e}");
tracing::debug!("{e:?}");
continue;
}
})).await.with_kind(ErrorKind::Network)?;
if handle_message(ws.next().await, &mut output).await? {
break
}
}
}
Err(e) => {
tracing::error!("{e}");
tracing::debug!("{e:?}");
}
}
}
} else {
ws.send(Message::Close(None)).await.with_kind(ErrorKind::Network)?;
}
}
msg = ws.next() => {
if handle_message(msg, &mut output).await? {
break;
}
}
}
}
Ok(())
}
pub async fn stats(ctx: RpcContext) -> Result<BTreeMap<PackageId, Option<ServiceStats>>, Error> {
let ids = ctx.db.peek().await.as_public().as_package_data().keys()?;

View File

@@ -27,14 +27,11 @@ use tokio::sync::Mutex;
use crate::auth::{Sessions, check_password, write_shadow};
use crate::context::RpcContext;
use crate::db::model::Database;
use crate::middleware::signature::{SignatureAuth, SignatureAuthContext};
use crate::prelude::*;
use crate::rpc_continuations::OpenAuthedContinuations;
use crate::sign::AnyVerifyingKey;
use crate::util::Invoke;
use crate::util::io::{create_file_mod, read_file_to_string};
use crate::util::iter::TransposeResultIterExt;
use crate::util::serde::BASE64;
use crate::util::sync::SyncMutex;
@@ -66,65 +63,6 @@ pub trait AuthContext: SignatureAuthContext {
}
}
impl SignatureAuthContext for RpcContext {
type Database = Database;
type AdditionalMetadata = ();
type CheckPubkeyRes = ();
fn db(&self) -> &TypedPatchDb<Self::Database> {
&self.db
}
async fn sig_context(
&self,
) -> impl IntoIterator<Item = Result<impl AsRef<str> + Send, Error>> + Send {
let peek = self.db.peek().await;
self.account
.read()
.await
.hostnames()
.into_iter()
.map(Ok)
.chain(
peek.as_public()
.as_server_info()
.as_network()
.as_host()
.as_public_domains()
.keys()
.map(|k| k.into_iter())
.transpose(),
)
.chain(
peek.as_public()
.as_server_info()
.as_network()
.as_host()
.as_private_domains()
.de()
.map(|k| k.into_iter())
.transpose(),
)
.collect::<Vec<_>>()
}
fn check_pubkey(
db: &Model<Self::Database>,
pubkey: Option<&AnyVerifyingKey>,
_: Self::AdditionalMetadata,
) -> Result<Self::CheckPubkeyRes, Error> {
if let Some(pubkey) = pubkey {
if db.as_private().as_auth_pubkeys().de()?.contains(pubkey) {
return Ok(());
}
}
Err(Error::new(
eyre!("Developer Key is not authorized"),
ErrorKind::IncorrectPassword,
))
}
async fn post_auth_hook(&self, _: Self::CheckPubkeyRes, _: &RpcRequest) -> Result<(), Error> {
Ok(())
}
}
impl AuthContext for RpcContext {
const LOCAL_AUTH_COOKIE_PATH: &str = "/run/startos/rpc.authcookie";
const LOCAL_AUTH_COOKIE_OWNERSHIP: &str = "root:startos";
@@ -439,7 +377,7 @@ impl<C: AuthContext> Middleware<C> for Auth {
));
}
if let Some(user_agent) = self.user_agent.as_ref().and_then(|h| h.to_str().ok()) {
request.params["__auth_userAgent"] =
request.params["__Auth_userAgent"] =
Value::String(Arc::new(user_agent.to_owned()))
// TODO: will this panic?
}
@@ -458,7 +396,7 @@ impl<C: AuthContext> Middleware<C> for Auth {
{
match HasValidSession::from_header(self.cookie.as_ref(), context).await? {
HasValidSession(SessionType::Session(s)) if metadata.get_session => {
request.params["__auth_session"] =
request.params["__Auth_session"] =
Value::String(Arc::new(s.hashed().deref().to_owned()));
}
_ => (),

View File

@@ -0,0 +1,55 @@
use std::net::SocketAddr;
use axum::extract::Request;
use axum::response::Response;
use imbl_value::json;
use rpc_toolkit::Middleware;
use serde::Deserialize;
/// RPC middleware that captures the socket addresses of the HTTP connection
/// carrying a request, so handlers that opt in can read them from params.
#[derive(Clone, Default)]
pub struct ConnectInfo {
    // Remote (client-side) address, when axum provided connection info.
    peer_addr: Option<SocketAddr>,
    // Local (server-side) address, when axum provided connection info.
    local_addr: Option<SocketAddr>,
}

impl ConnectInfo {
    /// Create an empty middleware; the addresses are populated per-request
    /// in `process_http_request`.
    pub fn new() -> Self {
        Self::default()
    }
}
/// Per-handler metadata controlling the [`ConnectInfo`] middleware.
#[derive(Deserialize)]
pub struct Metadata {
    // When true, the captured addresses are injected into the RPC params.
    get_connect_info: bool,
}
impl<Context: Send + Sync + 'static> Middleware<Context> for ConnectInfo {
    type Metadata = Metadata;
    /// Capture peer/local socket addresses from axum's `ConnectInfo`
    /// request extension, when present.
    async fn process_http_request(
        &mut self,
        _: &Context,
        request: &mut Request,
    ) -> Result<(), Response> {
        let info = request.extensions().get().cloned();
        if let Some(axum::extract::ConnectInfo((peer, local))) = info {
            self.peer_addr = Some(peer);
            self.local_addr = Some(local);
        }
        Ok(())
    }
    /// Inject the captured addresses into the RPC params for handlers whose
    /// metadata sets `get_connect_info`.
    async fn process_rpc_request(
        &mut self,
        _: &Context,
        metadata: Self::Metadata,
        request: &mut rpc_toolkit::RpcRequest,
    ) -> Result<(), rpc_toolkit::RpcResponse> {
        if !metadata.get_connect_info {
            return Ok(());
        }
        if let Some(peer) = self.peer_addr {
            request.params["__ConnectInfo_peer_addr"] = json!(peer);
        }
        if let Some(local) = self.local_addr {
            request.params["__ConnectInfo_local_addr"] = json!(local);
        }
        Ok(())
    }
}

View File

@@ -1,4 +1,5 @@
pub mod auth;
pub mod connect_info;
pub mod cors;
pub mod db;
pub mod signature;

View File

@@ -5,7 +5,8 @@ use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use axum::body::Body;
use axum::extract::Request;
use http::HeaderValue;
use http::{HeaderMap, HeaderValue};
use reqwest::Client;
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{Context, Middleware, RpcRequest, RpcResponse};
use serde::Deserialize;
@@ -13,13 +14,17 @@ use serde::de::DeserializeOwned;
use tokio::sync::Mutex;
use url::Url;
use crate::context::CliContext;
use crate::context::{CliContext, RpcContext};
use crate::db::model::Database;
use crate::prelude::*;
use crate::sign::commitment::Commitment;
use crate::sign::commitment::request::RequestCommitment;
use crate::sign::{AnySignature, AnySigningKey, AnyVerifyingKey, SignatureScheme};
use crate::util::iter::TransposeResultIterExt;
use crate::util::serde::Base64;
pub const AUTH_SIG_HEADER: &str = "X-StartOS-Auth-Sig";
pub trait SignatureAuthContext: Context {
type Database: HasModel<Model = Model<Self::Database>> + Send + Sync;
type AdditionalMetadata: DeserializeOwned + Send;
@@ -41,7 +46,82 @@ pub trait SignatureAuthContext: Context {
) -> impl Future<Output = Result<(), Error>> + Send;
}
pub const AUTH_SIG_HEADER: &str = "X-StartOS-Auth-Sig";
/// Signature-based authentication for the main server context.
impl SignatureAuthContext for RpcContext {
    type Database = Database;
    // No extra per-handler metadata and no payload from the pubkey check.
    type AdditionalMetadata = ();
    type CheckPubkeyRes = ();

    fn db(&self) -> &TypedPatchDb<Self::Database> {
        &self.db
    }

    /// Collect every name this server answers to, for use as valid
    /// signature contexts: the account's hostnames plus the public and
    /// private domains configured on the OS host in the database.
    async fn sig_context(
        &self,
    ) -> impl IntoIterator<Item = Result<impl AsRef<str> + Send, Error>> + Send {
        let peek = self.db.peek().await;
        self.account.peek(|a| {
            a.hostnames()
                .into_iter()
                .map(Ok)
                .chain(
                    peek.as_public()
                        .as_server_info()
                        .as_network()
                        .as_host()
                        .as_public_domains()
                        .keys()
                        .map(|k| k.into_iter())
                        .transpose(),
                )
                .chain(
                    peek.as_public()
                        .as_server_info()
                        .as_network()
                        .as_host()
                        .as_private_domains()
                        .de()
                        .map(|k| k.into_iter())
                        .transpose(),
                )
                // materialize so the DB snapshot isn't borrowed past return
                .collect::<Vec<_>>()
        })
    }

    /// Authorize only signers whose key is in the private `auth_pubkeys`
    /// set; anything else (including an absent key) is rejected.
    fn check_pubkey(
        db: &Model<Self::Database>,
        pubkey: Option<&AnyVerifyingKey>,
        _: Self::AdditionalMetadata,
    ) -> Result<Self::CheckPubkeyRes, Error> {
        if let Some(pubkey) = pubkey {
            if db.as_private().as_auth_pubkeys().de()?.contains(pubkey) {
                return Ok(());
            }
        }
        Err(Error::new(
            eyre!("Developer Key is not authorized"),
            ErrorKind::IncorrectPassword,
        ))
    }

    /// No additional work after a successful signature check.
    async fn post_auth_hook(&self, _: Self::CheckPubkeyRes, _: &RpcRequest) -> Result<(), Error> {
        Ok(())
    }
}
pub trait SigningContext {
fn signing_key(&self) -> Result<AnySigningKey, Error>;
}
impl SigningContext for CliContext {
fn signing_key(&self) -> Result<AnySigningKey, Error> {
Ok(AnySigningKey::Ed25519(self.developer_key()?.clone()))
}
}
impl SigningContext for RpcContext {
fn signing_key(&self) -> Result<AnySigningKey, Error> {
Ok(AnySigningKey::Ed25519(
self.account.peek(|a| a.developer_key.clone()),
))
}
}
#[derive(Deserialize)]
pub struct Metadata<Additional> {
@@ -203,7 +283,7 @@ impl<C: SignatureAuthContext> Middleware<C> for SignatureAuth {
let signer = self.signer.take().transpose()?;
if metadata.get_signer {
if let Some(signer) = &signer {
request.params["__auth_signer"] = to_value(signer)?;
request.params["__Auth_signer"] = to_value(signer)?;
}
}
let db = context.db().peek().await;
@@ -216,10 +296,11 @@ impl<C: SignatureAuthContext> Middleware<C> for SignatureAuth {
}
}
pub async fn call_remote(
ctx: &CliContext,
pub async fn call_remote<Ctx: SigningContext + AsRef<Client>>(
ctx: &Ctx,
url: Url,
sig_context: &str,
headers: HeaderMap,
sig_context: Option<&str>,
method: &str,
params: Value,
) -> Result<Value, RpcError> {
@@ -235,16 +316,16 @@ pub async fn call_remote(
};
let body = serde_json::to_vec(&rpc_req)?;
let mut req = ctx
.client
.as_ref()
.request(Method::POST, url)
.header(CONTENT_TYPE, "application/json")
.header(ACCEPT, "application/json")
.header(CONTENT_LENGTH, body.len());
if let Ok(key) = ctx.developer_key() {
.header(CONTENT_LENGTH, body.len())
.headers(headers);
if let (Some(sig_ctx), Ok(key)) = (sig_context, ctx.signing_key()) {
req = req.header(
AUTH_SIG_HEADER,
SignatureHeader::sign(&AnySigningKey::Ed25519(key.clone()), &body, sig_context)?
.to_header(),
SignatureHeader::sign(&key, &body, sig_ctx)?.to_header(),
);
}
let res = req.body(body).send().await?;

View File

@@ -1,9 +1,12 @@
use std::collections::{BTreeMap, BTreeSet};
use std::net::IpAddr;
use std::str::FromStr;
use std::sync::Arc;
use async_acme::acme::Identifier;
use async_acme::acme::{ACME_TLS_ALPN_NAME, Identifier};
use clap::Parser;
use clap::builder::ValueParserFactory;
use futures::StreamExt;
use imbl_value::InternedString;
use itertools::Itertools;
use models::{ErrorData, FromStrParser};
@@ -11,14 +14,209 @@ use openssl::pkey::{PKey, Private};
use openssl::x509::X509;
use rpc_toolkit::{Context, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use tokio_rustls::rustls::ServerConfig;
use tokio_rustls::rustls::crypto::CryptoProvider;
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio_rustls::rustls::server::ClientHello;
use tokio_rustls::rustls::sign::CertifiedKey;
use ts_rs::TS;
use url::Url;
use crate::context::{CliContext, RpcContext};
use crate::db::model::Database;
use crate::db::model::public::AcmeSettings;
use crate::db::{DbAccess, DbAccessByKey, DbAccessMut};
use crate::net::tls::{SingleCertResolver, TlsHandler};
use crate::net::web_server::Accept;
use crate::prelude::*;
use crate::util::serde::{Pem, Pkcs8Doc};
use crate::util::sync::{SyncMutex, Watch};
pub type AcmeTlsAlpnCache =
Arc<SyncMutex<BTreeMap<InternedString, Watch<Option<Arc<CertifiedKey>>>>>>;
pub struct AcmeTlsHandler<M: HasModel, S> {
pub db: TypedPatchDb<M>,
pub acme_cache: AcmeTlsAlpnCache,
pub crypto_provider: Arc<CryptoProvider>,
pub get_provider: S,
pub in_progress: Watch<BTreeSet<BTreeSet<InternedString>>>,
}
impl<M, S> AcmeTlsHandler<M, S>
where
for<'a> M: DbAccessByKey<AcmeSettings, Key<'a> = &'a AcmeProvider>
+ DbAccessMut<AcmeCertStore>
+ HasModel<Model = Model<M>>
+ Send
+ Sync,
S: GetAcmeProvider + Clone,
{
pub async fn get_cert(&self, san_info: &BTreeSet<InternedString>) -> Option<CertifiedKey> {
let provider = self.get_provider.get_provider(san_info).await?;
let provider = provider.as_ref();
loop {
let peek = self.db.peek().await;
let store = <M as DbAccess<AcmeCertStore>>::access(&peek);
if let Some(cert) = store
.as_certs()
.as_idx(&provider.0)
.and_then(|p| p.as_idx(JsonKey::new_ref(san_info)))
{
let cert = cert.de().log_err()?;
return Some(
CertifiedKey::from_der(
cert.fullchain
.into_iter()
.map(|c| Ok(CertificateDer::from(c.to_der()?)))
.collect::<Result<_, Error>>()
.log_err()?,
PrivateKeyDer::from(PrivatePkcs8KeyDer::from(
cert.key.0.private_key_to_pkcs8().log_err()?,
)),
&*self.crypto_provider,
)
.log_err()?,
);
}
if !self.in_progress.send_if_modified(|x| {
if !x.contains(san_info) {
x.insert(san_info.clone());
true
} else {
false
}
}) {
self.in_progress
.clone()
.wait_for(|x| !x.contains(san_info))
.await;
continue;
}
let contact = <M as DbAccessByKey<AcmeSettings>>::access_by_key(&peek, &provider)?
.as_contact()
.de()
.log_err()?;
let identifiers: Vec<_> = san_info
.iter()
.map(|d| match d.parse::<IpAddr>() {
Ok(a) => Identifier::Ip(a),
_ => Identifier::Dns((&**d).into()),
})
.collect::<Vec<_>>();
let cache_entries = san_info
.iter()
.cloned()
.map(|d| (d, Watch::new(None)))
.collect::<BTreeMap<_, _>>();
self.acme_cache.mutate(|c| {
c.extend(cache_entries.iter().map(|(k, v)| (k.clone(), v.clone())));
});
let cert = async_acme::rustls_helper::order(
|identifier, cert| {
let domain = InternedString::from_display(&identifier);
if let Some(entry) = cache_entries.get(&domain) {
entry.send(Some(Arc::new(cert)));
}
Ok(())
},
provider.0.as_str(),
&identifiers,
Some(&AcmeCertCache(&self.db)),
&contact,
)
.await
.log_err()?;
self.acme_cache
.mutate(|c| c.retain(|c, _| !cache_entries.contains_key(c)));
self.in_progress.send_modify(|i| i.remove(san_info));
return Some(cert);
}
}
}
pub trait GetAcmeProvider {
fn get_provider<'a, 'b: 'a>(
&'b self,
san_info: &'a BTreeSet<InternedString>,
) -> impl Future<Output = Option<impl AsRef<AcmeProvider> + Send + 'b>> + Send + 'a;
}
impl<'a, A, M, S> TlsHandler<'a, A> for Arc<AcmeTlsHandler<M, S>>
where
A: Accept + 'a,
<A as Accept>::Metadata: Send + Sync,
for<'m> M: DbAccessByKey<AcmeSettings, Key<'m> = &'m AcmeProvider>
+ DbAccessMut<AcmeCertStore>
+ HasModel<Model = Model<M>>
+ Send
+ Sync,
S: GetAcmeProvider + Clone + Send + Sync,
{
async fn get_config(
&'a mut self,
hello: &'a ClientHello<'a>,
_: &'a <A as Accept>::Metadata,
) -> Option<ServerConfig> {
let domain = hello.server_name()?;
if hello
.alpn()
.into_iter()
.flatten()
.any(|a| a == ACME_TLS_ALPN_NAME)
{
let cert = self
.acme_cache
.peek(|c| c.get(domain).cloned())
.ok_or_else(|| {
Error::new(
eyre!("No challenge recv available for {domain}"),
ErrorKind::OpenSsl,
)
})
.log_err()?;
tracing::info!("Waiting for verification cert for {domain}");
let cert = cert
.filter(|c| futures::future::ready(c.is_some()))
.next()
.await
.flatten()?;
tracing::info!("Verification cert received for {domain}");
let mut cfg = ServerConfig::builder_with_provider(self.crypto_provider.clone())
.with_safe_default_protocol_versions()
.log_err()?
.with_no_client_auth()
.with_cert_resolver(Arc::new(SingleCertResolver(cert)));
cfg.alpn_protocols = vec![ACME_TLS_ALPN_NAME.to_vec()];
tracing::info!("performing ACME auth challenge");
return Some(cfg);
}
let domains: BTreeSet<InternedString> = [domain.into()].into_iter().collect();
let crypto_provider = self.crypto_provider.clone();
if let Some(cert) = self.get_cert(&domains).await {
return Some(
ServerConfig::builder_with_provider(crypto_provider)
.with_safe_default_protocol_versions()
.log_err()?
.with_no_client_auth()
.with_cert_resolver(Arc::new(SingleCertResolver(Arc::new(cert)))),
);
}
None
}
}
#[derive(Debug, Default, Deserialize, Serialize, HasModel)]
#[model = "Model<Self>"]
@@ -32,29 +230,35 @@ impl AcmeCertStore {
}
}
impl DbAccess<AcmeCertStore> for Database {
fn access<'a>(db: &'a Model<Self>) -> &'a Model<AcmeCertStore> {
db.as_private().as_key_store().as_acme()
}
}
impl DbAccessMut<AcmeCertStore> for Database {
fn access_mut<'a>(db: &'a mut Model<Self>) -> &'a mut Model<AcmeCertStore> {
db.as_private_mut().as_key_store_mut().as_acme_mut()
}
}
#[derive(Debug, Deserialize, Serialize)]
pub struct AcmeCert {
pub key: Pem<PKey<Private>>,
pub fullchain: Vec<Pem<X509>>,
}
pub struct AcmeCertCache<'a>(pub &'a TypedPatchDb<Database>);
pub struct AcmeCertCache<'a, M: HasModel>(pub &'a TypedPatchDb<M>);
#[async_trait::async_trait]
impl<'a> async_acme::cache::AcmeCache for AcmeCertCache<'a> {
impl<'a, M> async_acme::cache::AcmeCache for AcmeCertCache<'a, M>
where
M: HasModel<Model = Model<M>> + DbAccessMut<AcmeCertStore> + Send + Sync,
{
type Error = ErrorData;
async fn read_account(&self, contacts: &[&str]) -> Result<Option<Vec<u8>>, Self::Error> {
let contacts = JsonKey::new(contacts.into_iter().map(|s| (*s).to_owned()).collect_vec());
let Some(account) = self
.0
.peek()
.await
.into_private()
.into_key_store()
.into_acme()
.into_accounts()
.into_idx(&contacts)
else {
let peek = self.0.peek().await;
let Some(account) = M::access(&peek).as_accounts().as_idx(&contacts) else {
return Ok(None);
};
Ok(Some(account.de()?.0.document.into_vec()))
@@ -68,9 +272,7 @@ impl<'a> async_acme::cache::AcmeCache for AcmeCertCache<'a> {
};
self.0
.mutate(|db| {
db.as_private_mut()
.as_key_store_mut()
.as_acme_mut()
M::access_mut(db)
.as_accounts_mut()
.insert(&contacts, &Pem::new(key))
})
@@ -96,16 +298,11 @@ impl<'a> async_acme::cache::AcmeCache for AcmeCertCache<'a> {
let directory_url = directory_url
.parse::<Url>()
.with_kind(ErrorKind::ParseUrl)?;
let Some(cert) = self
.0
.peek()
.await
.into_private()
.into_key_store()
.into_acme()
.into_certs()
.into_idx(&directory_url)
.and_then(|a| a.into_idx(&identifiers))
let peek = self.0.peek().await;
let Some(cert) = M::access(&peek)
.as_certs()
.as_idx(&directory_url)
.and_then(|a| a.as_idx(&identifiers))
else {
return Ok(None);
};
@@ -160,9 +357,7 @@ impl<'a> async_acme::cache::AcmeCache for AcmeCertCache<'a> {
};
self.0
.mutate(|db| {
db.as_private_mut()
.as_key_store_mut()
.as_acme_mut()
M::access_mut(db)
.as_certs_mut()
.upsert(&directory_url, || Ok(BTreeMap::new()))?
.insert(&identifiers, &cert)
@@ -235,6 +430,11 @@ impl AsRef<str> for AcmeProvider {
self.0.as_str()
}
}
impl AsRef<AcmeProvider> for AcmeProvider {
fn as_ref(&self) -> &AcmeProvider {
self
}
}
impl ValueParserFactory for AcmeProvider {
type Parser = FromStrParser<Self>;
fn value_parser() -> Self::Parser {

View File

@@ -11,36 +11,36 @@ use futures::future::BoxFuture;
use futures::{FutureExt, StreamExt};
use helpers::NonDetachingJoinHandle;
use hickory_client::client::Client;
use hickory_client::proto::DnsHandle;
use hickory_client::proto::runtime::TokioRuntimeProvider;
use hickory_client::proto::tcp::TcpClientStream;
use hickory_client::proto::udp::UdpClientStream;
use hickory_client::proto::xfer::DnsRequestOptions;
use hickory_client::proto::DnsHandle;
use hickory_server::ServerFuture;
use hickory_server::authority::MessageResponseBuilder;
use hickory_server::proto::op::{Header, ResponseCode};
use hickory_server::proto::rr::{Name, Record, RecordType};
use hickory_server::server::{Request, RequestHandler, ResponseHandler, ResponseInfo};
use hickory_server::ServerFuture;
use imbl::OrdMap;
use imbl_value::InternedString;
use itertools::Itertools;
use models::{GatewayId, OptionExt, PackageId};
use patch_db::json_ptr::JsonPointer;
use rpc_toolkit::{
from_fn_async, from_fn_blocking, Context, HandlerArgs, HandlerExt, ParentHandler,
Context, HandlerArgs, HandlerExt, ParentHandler, from_fn_async, from_fn_blocking,
};
use serde::{Deserialize, Serialize};
use tokio::net::{TcpListener, UdpSocket};
use tracing::instrument;
use crate::context::RpcContext;
use crate::db::model::public::NetworkInterfaceInfo;
use crate::context::{CliContext, RpcContext};
use crate::db::model::Database;
use crate::db::model::public::NetworkInterfaceInfo;
use crate::net::gateway::NetworkInterfaceWatcher;
use crate::prelude::*;
use crate::util::actor::background::BackgroundJobQueue;
use crate::util::io::file_string_stream;
use crate::util::serde::{display_serializable, HandlerExtSerde};
use crate::util::serde::{HandlerExtSerde, display_serializable};
use crate::util::sync::{SyncRwLock, Watch};
pub fn dns_api<C: Context>() -> ParentHandler<C> {
@@ -66,7 +66,36 @@ pub fn dns_api<C: Context>() -> ParentHandler<C> {
"set-static",
from_fn_async(set_static_dns)
.no_display()
.with_about("Set static DNS servers"),
.with_about("Set static DNS servers")
.with_call_remote::<CliContext>(),
)
.subcommand(
"dump-table",
from_fn_async(dump_table)
.with_display_serializable()
.with_custom_display_fn(|HandlerArgs { params, .. }, res| {
use prettytable::*;
if let Some(format) = params.format {
return display_serializable(format, res);
}
let mut table = Table::new();
table.add_row(row![bc => "FQDN", "DESTINATION"]);
for (hostname, destination) in res {
if let Some(ip) = destination {
table.add_row(row![hostname, ip]);
} else {
table.add_row(row![hostname, "SELF"]);
}
}
table.print_tty(false)?;
Ok(())
})
.with_about("Dump address resolution table")
.with_call_remote::<CliContext>(),
)
}
@@ -142,6 +171,38 @@ pub async fn set_static_dns(
.result
}
pub async fn dump_table(
ctx: RpcContext,
) -> Result<BTreeMap<InternedString, Option<IpAddr>>, Error> {
Ok(ctx
.net_controller
.dns
.resolve
.upgrade()
.or_not_found("DnsController")?
.peek(|map| {
map.private_domains
.iter()
.map(|(d, _)| (d.clone(), None))
.chain(map.services.iter().filter_map(|(svc, ip)| {
ip.iter()
.find(|(_, rc)| rc.strong_count() > 0)
.map(|(ip, _)| {
(
svc.as_ref().map_or(
InternedString::from_static("startos"),
|svc| {
InternedString::from_display(&lazy_format!("{svc}.startos"))
},
),
Some(IpAddr::V4(*ip)),
)
})
}))
.collect()
}))
}
#[derive(Default)]
struct ResolveMap {
private_domains: BTreeMap<InternedString, Weak<()>>,
@@ -222,9 +283,9 @@ impl DnsClient {
});
loop {
if let Err::<(), Error>(e) = async {
let mut static_changed = db
let mut dns_changed = db
.subscribe(
"/public/serverInfo/network/dns/staticServers"
"/public/serverInfo/network/dns"
.parse::<JsonPointer>()
.with_kind(ErrorKind::Database)?,
)
@@ -275,7 +336,7 @@ impl DnsClient {
Client::new(stream, sender, None)
.await
.with_kind(ErrorKind::Network)?;
bg.insert(*addr, bg_thread.boxed());
bg.insert(*addr, bg_thread.fuse().boxed());
client
};
new.push((*addr, client));
@@ -286,7 +347,7 @@ impl DnsClient {
client.replace(new);
}
futures::future::select(
static_changed.recv().boxed(),
dns_changed.recv().boxed(),
futures::future::join(
futures::future::join_all(bg.values_mut()),
futures::future::pending::<()>(),
@@ -333,10 +394,20 @@ struct Resolver {
resolve: Arc<SyncRwLock<ResolveMap>>,
}
impl Resolver {
fn resolve(&self, name: &Name, src: IpAddr) -> Option<Vec<IpAddr>> {
fn resolve(&self, name: &Name, mut src: IpAddr) -> Option<Vec<IpAddr>> {
if name.zone_of(&*LOCALHOST) {
return Some(vec![Ipv4Addr::LOCALHOST.into(), Ipv6Addr::LOCALHOST.into()]);
}
src = match src {
IpAddr::V6(v6) => {
if let Some(v4) = v6.to_ipv4_mapped() {
IpAddr::V4(v4)
} else {
IpAddr::V6(v6)
}
}
a => a,
};
self.resolve.peek(|r| {
if r.private_domains
.get(&*name.to_lowercase().to_utf8().trim_end_matches('.'))
@@ -344,8 +415,7 @@ impl Resolver {
{
if let Some(res) = self.net_iface.peek(|i| {
i.values()
.chain([NetworkInterfaceInfo::lxc_bridge().1])
.flat_map(|i| i.ip_info.as_ref())
.filter_map(|i| i.ip_info.as_ref())
.find(|i| i.subnets.iter().any(|s| s.contains(&src)))
.map(|ip_info| {
let mut res = ip_info.subnets.iter().collect::<Vec<_>>();
@@ -354,6 +424,8 @@ impl Resolver {
})
}) {
return Some(res);
} else {
tracing::warn!("Could not determine source interface of {src}");
}
}
if STARTOS.zone_of(name) || EMBASSY.zone_of(name) {
@@ -406,11 +478,7 @@ impl RequestHandler for Resolver {
header,
&ip.into_iter()
.filter_map(|a| {
if let IpAddr::V4(a) = a {
Some(a)
} else {
None
}
if let IpAddr::V4(a) = a { Some(a) } else { None }
})
.map(|ip| {
Record::from_rdata(
@@ -436,11 +504,7 @@ impl RequestHandler for Resolver {
header,
&ip.into_iter()
.filter_map(|a| {
if let IpAddr::V6(a) = a {
Some(a)
} else {
None
}
if let IpAddr::V6(a) = a { Some(a) } else { None }
})
.map(|ip| {
Record::from_rdata(

View File

@@ -1,25 +1,26 @@
use std::collections::{BTreeMap, BTreeSet};
use std::net::{IpAddr, SocketAddr, SocketAddrV6};
use std::net::{IpAddr, SocketAddrV4};
use std::sync::{Arc, Weak};
use std::time::Duration;
use futures::channel::oneshot;
use helpers::NonDetachingJoinHandle;
use id_pool::IdPool;
use iddqd::{IdOrdItem, IdOrdMap};
use imbl::OrdMap;
use models::GatewayId;
use rpc_toolkit::{from_fn_async, Context, HandlerArgs, HandlerExt, ParentHandler};
use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use tokio::sync::mpsc;
use crate::context::{CliContext, RpcContext};
use crate::db::model::public::NetworkInterfaceInfo;
use crate::net::gateway::{DynInterfaceFilter, InterfaceFilter, SecureFilter};
use crate::net::utils::ipv6_is_link_local;
use crate::net::gateway::{DynInterfaceFilter, InterfaceFilter};
use crate::prelude::*;
use crate::util::serde::{display_serializable, HandlerExtSerde};
use crate::util::sync::Watch;
use crate::util::Invoke;
use crate::util::serde::{HandlerExtSerde, display_serializable};
use crate::util::sync::Watch;
pub const START9_BRIDGE_IFACE: &str = "lxcbr0";
pub const FIRST_DYNAMIC_PRIVATE_PORT: u16 = 49152;
@@ -60,17 +61,10 @@ pub fn forward_api<C: Context>() -> ParentHandler<C> {
}
let mut table = Table::new();
table.add_row(row![bc => "FROM", "TO", "FILTER / GATEWAY"]);
table.add_row(row![bc => "FROM", "TO", "FILTER"]);
for (external, target) in res.0 {
table.add_row(row![external, target.target, target.filter]);
for (source, gateway) in target.gateways {
table.add_row(row![
format!("{}:{}", source, external),
target.target,
gateway
]);
}
}
table.print_tty(false)?;
@@ -81,162 +75,368 @@ pub fn forward_api<C: Context>() -> ParentHandler<C> {
)
}
struct ForwardRequest {
external: u16,
target: SocketAddr,
filter: DynInterfaceFilter,
struct ForwardMapping {
source: SocketAddrV4,
target: SocketAddrV4,
rc: Weak<()>,
}
#[derive(Clone)]
struct ForwardEntry {
external: u16,
target: SocketAddr,
prev_filter: DynInterfaceFilter,
forwards: BTreeMap<SocketAddr, GatewayId>,
rc: Weak<()>,
#[derive(Default)]
struct PortForwardState {
mappings: BTreeMap<SocketAddrV4, ForwardMapping>, // source -> target
}
impl ForwardEntry {
fn new(external: u16, target: SocketAddr, rc: Weak<()>) -> Self {
Self {
external,
target,
prev_filter: false.into_dyn(),
forwards: BTreeMap::new(),
rc,
}
}
fn take(&mut self) -> Self {
Self {
external: self.external,
target: self.target,
prev_filter: std::mem::replace(&mut self.prev_filter, false.into_dyn()),
forwards: std::mem::take(&mut self.forwards),
rc: self.rc.clone(),
}
}
async fn destroy(mut self) -> Result<(), Error> {
while let Some((source, interface)) = self.forwards.pop_first() {
unforward(interface.as_str(), source, self.target).await?;
}
Ok(())
}
async fn update(
impl PortForwardState {
async fn add_forward(
&mut self,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
filter: Option<DynInterfaceFilter>,
) -> Result<(), Error> {
if self.rc.strong_count() == 0 {
return self.take().destroy().await;
}
let filter_ref = filter.as_ref().unwrap_or(&self.prev_filter);
let mut keep = BTreeSet::<SocketAddr>::new();
for (iface, info) in ip_info
.iter()
// .chain([NetworkInterfaceInfo::loopback()])
.filter(|(id, info)| filter_ref.filter(*id, *info))
{
if let Some(ip_info) = &info.ip_info {
for ipnet in &ip_info.subnets {
let addr = match ipnet.addr() {
IpAddr::V6(ip6) => SocketAddrV6::new(
ip6,
self.external,
0,
if ipv6_is_link_local(ip6) {
ip_info.scope_id
} else {
0
},
)
.into(),
ip => SocketAddr::new(ip, self.external),
};
keep.insert(addr);
if !self.forwards.contains_key(&addr) {
forward(iface.as_str(), addr, self.target).await?;
self.forwards.insert(addr, iface.clone());
}
source: SocketAddrV4,
target: SocketAddrV4,
) -> Result<Arc<()>, Error> {
if let Some(existing) = self.mappings.get_mut(&source) {
if existing.target == target {
if let Some(existing_rc) = existing.rc.upgrade() {
return Ok(existing_rc);
} else {
let rc = Arc::new(());
existing.rc = Arc::downgrade(&rc);
return Ok(rc);
}
} else {
// Different target, need to remove old and add new
if let Some(mapping) = self.mappings.remove(&source) {
unforward(mapping.source, mapping.target).await?;
}
}
}
let rm = self
.forwards
.keys()
.copied()
.filter(|a| !keep.contains(a))
.collect::<Vec<_>>();
for rm in rm {
if let Some((source, interface)) = self.forwards.remove_entry(&rm) {
unforward(interface.as_str(), source, self.target).await?;
let rc = Arc::new(());
forward(source, target).await?;
self.mappings.insert(
source,
ForwardMapping {
source,
target,
rc: Arc::downgrade(&rc),
},
);
Ok(rc)
}
async fn gc(&mut self) -> Result<(), Error> {
let to_remove: Vec<SocketAddrV4> = self
.mappings
.iter()
.filter(|(_, mapping)| mapping.rc.strong_count() == 0)
.map(|(source, _)| *source)
.collect();
for source in to_remove {
if let Some(mapping) = self.mappings.remove(&source) {
unforward(mapping.source, mapping.target).await?;
}
}
if let Some(filter) = filter {
self.prev_filter = filter;
}
Ok(())
}
async fn update_request(
&mut self,
ForwardRequest {
external,
target,
filter,
rc,
}: ForwardRequest,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
) -> Result<(), Error> {
if external != self.external || target != self.target {
self.take().destroy().await?;
*self = Self::new(external, target, rc);
self.update(ip_info, Some(filter)).await?;
} else {
self.rc = rc;
self.update(ip_info, Some(filter).filter(|f| f != &self.prev_filter))
.await?;
}
Ok(())
fn dump(&self) -> BTreeMap<SocketAddrV4, SocketAddrV4> {
self.mappings
.iter()
.filter(|(_, mapping)| mapping.rc.strong_count() > 0)
.map(|(source, mapping)| (*source, mapping.target))
.collect()
}
}
impl Drop for ForwardEntry {
impl Drop for PortForwardState {
fn drop(&mut self) {
if !self.forwards.is_empty() {
let take = self.take();
if !self.mappings.is_empty() {
let mappings = std::mem::take(&mut self.mappings);
tokio::spawn(async move {
take.destroy().await.log_err();
for (_, mapping) in mappings {
unforward(mapping.source, mapping.target).await.log_err();
}
});
}
}
}
#[derive(Default, Clone)]
struct ForwardState {
state: BTreeMap<u16, ForwardEntry>,
enum PortForwardCommand {
AddForward {
source: SocketAddrV4,
target: SocketAddrV4,
respond: oneshot::Sender<Result<Arc<()>, Error>>,
},
Gc {
respond: oneshot::Sender<Result<(), Error>>,
},
Dump {
respond: oneshot::Sender<BTreeMap<SocketAddrV4, SocketAddrV4>>,
},
}
impl ForwardState {
pub struct PortForwardController {
req: mpsc::UnboundedSender<PortForwardCommand>,
_thread: NonDetachingJoinHandle<()>,
}
impl PortForwardController {
pub fn new() -> Self {
let (req_send, mut req_recv) = mpsc::unbounded_channel::<PortForwardCommand>();
let thread = NonDetachingJoinHandle::from(tokio::spawn(async move {
while let Err(e) = async {
Command::new("sysctl")
.arg("-w")
.arg("net.ipv4.ip_forward=1")
.invoke(ErrorKind::Network)
.await?;
if Command::new("iptables")
.arg("-t")
.arg("nat")
.arg("-C")
.arg("POSTROUTING")
.arg("-j")
.arg("MASQUERADE")
.invoke(ErrorKind::Network)
.await
.is_err()
{
Command::new("iptables")
.arg("-t")
.arg("nat")
.arg("-A")
.arg("POSTROUTING")
.arg("-j")
.arg("MASQUERADE")
.invoke(ErrorKind::Network)
.await?;
}
Ok::<_, Error>(())
}
.await
{
tracing::error!("error initializing PortForwardController: {e:#}");
tracing::debug!("{e:?}");
tokio::time::sleep(Duration::from_secs(5)).await;
}
let mut state = PortForwardState::default();
while let Some(cmd) = req_recv.recv().await {
match cmd {
PortForwardCommand::AddForward {
source,
target,
respond,
} => {
let result = state.add_forward(source, target).await;
respond.send(result).ok();
}
PortForwardCommand::Gc { respond } => {
let result = state.gc().await;
respond.send(result).ok();
}
PortForwardCommand::Dump { respond } => {
respond.send(state.dump()).ok();
}
}
}
}));
Self {
req: req_send,
_thread: thread,
}
}
pub async fn add_forward(
&self,
source: SocketAddrV4,
target: SocketAddrV4,
) -> Result<Arc<()>, Error> {
let (send, recv) = oneshot::channel();
self.req
.send(PortForwardCommand::AddForward {
source,
target,
respond: send,
})
.map_err(err_has_exited)?;
recv.await.map_err(err_has_exited)?
}
pub async fn gc(&self) -> Result<(), Error> {
let (send, recv) = oneshot::channel();
self.req
.send(PortForwardCommand::Gc { respond: send })
.map_err(err_has_exited)?;
recv.await.map_err(err_has_exited)?
}
pub async fn dump(&self) -> Result<BTreeMap<SocketAddrV4, SocketAddrV4>, Error> {
let (send, recv) = oneshot::channel();
self.req
.send(PortForwardCommand::Dump { respond: send })
.map_err(err_has_exited)?;
recv.await.map_err(err_has_exited)
}
}
struct InterfaceForwardRequest {
external: u16,
target: SocketAddrV4,
filter: DynInterfaceFilter,
rc: Arc<()>,
}
#[derive(Clone)]
struct InterfaceForwardEntry {
external: u16,
filter: BTreeMap<DynInterfaceFilter, (SocketAddrV4, Weak<()>)>,
// Maps source SocketAddr -> strong reference for the forward created in PortForwardController
forwards: BTreeMap<SocketAddrV4, Arc<()>>,
}
impl IdOrdItem for InterfaceForwardEntry {
type Key<'a> = u16;
fn key(&self) -> Self::Key<'_> {
self.external
}
iddqd::id_upcast!();
}
impl InterfaceForwardEntry {
fn new(external: u16) -> Self {
Self {
external,
filter: BTreeMap::new(),
forwards: BTreeMap::new(),
}
}
async fn update(
&mut self,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
port_forward: &PortForwardController,
) -> Result<(), Error> {
let mut keep = BTreeSet::<SocketAddrV4>::new();
for (iface, info) in ip_info.iter() {
if let Some(target) = self
.filter
.iter()
.filter(|(_, (_, rc))| rc.strong_count() > 0)
.find(|(filter, _)| filter.filter(iface, info))
.map(|(_, (target, _))| *target)
{
if let Some(ip_info) = &info.ip_info {
for addr in ip_info.subnets.iter().filter_map(|net| {
if let IpAddr::V4(ip) = net.addr() {
Some(SocketAddrV4::new(ip, self.external))
} else {
None
}
}) {
keep.insert(addr);
if !self.forwards.contains_key(&addr) {
let rc = port_forward.add_forward(addr, target).await?;
self.forwards.insert(addr, rc);
}
}
}
}
}
// Remove forwards that should no longer exist (drops the strong references)
self.forwards.retain(|addr, _| keep.contains(addr));
Ok(())
}
async fn update_request(
&mut self,
InterfaceForwardRequest {
external,
target,
filter,
mut rc,
}: InterfaceForwardRequest,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
port_forward: &PortForwardController,
) -> Result<Arc<()>, Error> {
if external != self.external {
return Err(Error::new(
eyre!("Mismatched external port in InterfaceForwardEntry"),
ErrorKind::InvalidRequest,
));
}
let entry = self
.filter
.entry(filter)
.or_insert_with(|| (target, Arc::downgrade(&rc)));
if entry.0 != target {
entry.0 = target;
entry.1 = Arc::downgrade(&rc);
}
if let Some(existing) = entry.1.upgrade() {
rc = existing;
} else {
entry.1 = Arc::downgrade(&rc);
}
self.update(ip_info, port_forward).await?;
Ok(rc)
}
async fn gc(
&mut self,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
port_forward: &PortForwardController,
) -> Result<(), Error> {
self.filter.retain(|_, (_, rc)| rc.strong_count() > 0);
self.update(ip_info, port_forward).await
}
}
struct InterfaceForwardState {
port_forward: PortForwardController,
state: IdOrdMap<InterfaceForwardEntry>,
}
impl InterfaceForwardState {
fn new(port_forward: PortForwardController) -> Self {
Self {
port_forward,
state: IdOrdMap::new(),
}
}
}
impl InterfaceForwardState {
async fn handle_request(
&mut self,
request: ForwardRequest,
request: InterfaceForwardRequest,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
) -> Result<(), Error> {
) -> Result<Arc<()>, Error> {
self.state
.entry(request.external)
.or_insert_with(|| ForwardEntry::new(request.external, request.target, Weak::new()))
.update_request(request, ip_info)
.or_insert_with(|| InterfaceForwardEntry::new(request.external))
.update_request(request, ip_info, &self.port_forward)
.await
}
async fn sync(
&mut self,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
) -> Result<(), Error> {
for entry in self.state.values_mut() {
entry.update(ip_info, None).await?;
for mut entry in self.state.iter_mut() {
entry.gc(ip_info, &self.port_forward).await?;
}
self.state.retain(|_, fwd| fwd.rc.strong_count() > 0);
Ok(())
}
}
@@ -250,65 +450,85 @@ fn err_has_exited<T>(_: T) -> Error {
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ForwardTable(pub BTreeMap<u16, ForwardTarget>);
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ForwardTarget {
pub target: SocketAddr,
pub target: SocketAddrV4,
pub filter: String,
pub gateways: BTreeMap<SocketAddr, GatewayId>,
}
impl From<&ForwardState> for ForwardTable {
fn from(value: &ForwardState) -> Self {
impl From<&InterfaceForwardState> for ForwardTable {
fn from(value: &InterfaceForwardState) -> Self {
Self(
value
.state
.iter()
.map(|(external, entry)| {
(
*external,
ForwardTarget {
target: entry.target,
filter: format!("{:?}", entry.prev_filter),
gateways: entry.forwards.clone(),
},
)
.flat_map(|entry| {
entry
.filter
.iter()
.filter(|(_, (_, rc))| rc.strong_count() > 0)
.map(|(filter, (target, _))| {
(
entry.external,
ForwardTarget {
target: *target,
filter: format!("{:?}", filter),
},
)
})
})
.collect(),
)
}
}
enum ForwardCommand {
Forward(ForwardRequest, oneshot::Sender<Result<(), Error>>),
enum InterfaceForwardCommand {
Forward(
InterfaceForwardRequest,
oneshot::Sender<Result<Arc<()>, Error>>,
),
Sync(oneshot::Sender<Result<(), Error>>),
DumpTable(oneshot::Sender<ForwardTable>),
}
#[test]
fn test() {
use crate::net::gateway::SecureFilter;
assert_ne!(
false.into_dyn(),
SecureFilter { secure: false }.into_dyn().into_dyn()
);
}
pub struct PortForwardController {
req: mpsc::UnboundedSender<ForwardCommand>,
pub struct InterfacePortForwardController {
req: mpsc::UnboundedSender<InterfaceForwardCommand>,
_thread: NonDetachingJoinHandle<()>,
}
impl PortForwardController {
impl InterfacePortForwardController {
pub fn new(mut ip_info: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>) -> Self {
let (req_send, mut req_recv) = mpsc::unbounded_channel::<ForwardCommand>();
let port_forward = PortForwardController::new();
let (req_send, mut req_recv) = mpsc::unbounded_channel::<InterfaceForwardCommand>();
let thread = NonDetachingJoinHandle::from(tokio::spawn(async move {
let mut state = ForwardState::default();
let mut state = InterfaceForwardState::new(port_forward);
let mut interfaces = ip_info.read_and_mark_seen();
loop {
tokio::select! {
msg = req_recv.recv() => {
if let Some(cmd) = msg {
match cmd {
ForwardCommand::Forward(req, re) => re.send(state.handle_request(req, &interfaces).await).ok(),
ForwardCommand::Sync(re) => re.send(state.sync(&interfaces).await).ok(),
ForwardCommand::DumpTable(re) => re.send((&state).into()).ok(),
InterfaceForwardCommand::Forward(req, re) => {
re.send(state.handle_request(req, &interfaces).await).ok()
}
InterfaceForwardCommand::Sync(re) => {
re.send(state.sync(&interfaces).await).ok()
}
InterfaceForwardCommand::DumpTable(re) => {
re.send((&state).into()).ok()
}
};
} else {
break;
@@ -317,61 +537,61 @@ impl PortForwardController {
_ = ip_info.changed() => {
interfaces = ip_info.read();
state.sync(&interfaces).await.log_err();
state.port_forward.gc().await.log_err();
}
}
}
}));
Self {
req: req_send,
_thread: thread,
}
}
pub async fn add(
&self,
external: u16,
filter: DynInterfaceFilter,
target: SocketAddr,
target: SocketAddrV4,
) -> Result<Arc<()>, Error> {
let rc = Arc::new(());
let (send, recv) = oneshot::channel();
self.req
.send(ForwardCommand::Forward(
ForwardRequest {
.send(InterfaceForwardCommand::Forward(
InterfaceForwardRequest {
external,
target,
filter,
rc: Arc::downgrade(&rc),
rc,
},
send,
))
.map_err(err_has_exited)?;
recv.await.map_err(err_has_exited)?.map(|_| rc)
recv.await.map_err(err_has_exited)?
}
pub async fn gc(&self) -> Result<(), Error> {
let (send, recv) = oneshot::channel();
self.req
.send(ForwardCommand::Sync(send))
.send(InterfaceForwardCommand::Sync(send))
.map_err(err_has_exited)?;
recv.await.map_err(err_has_exited)?
}
pub async fn dump_table(&self) -> Result<ForwardTable, Error> {
let (req, res) = oneshot::channel();
self.req
.send(ForwardCommand::DumpTable(req))
.send(InterfaceForwardCommand::DumpTable(req))
.map_err(err_has_exited)?;
res.await.map_err(err_has_exited)
}
}
async fn forward(interface: &str, source: SocketAddr, target: SocketAddr) -> Result<(), Error> {
if source.is_ipv6() {
return Ok(()); // TODO: socat? ip6tables?
}
async fn forward(source: SocketAddrV4, target: SocketAddrV4) -> Result<(), Error> {
Command::new("/usr/lib/startos/scripts/forward-port")
.env("iiface", interface)
.env("oiface", START9_BRIDGE_IFACE)
.env("sip", source.ip().to_string())
.env("dip", target.ip().to_string())
.env("sport", source.port().to_string())
@@ -381,14 +601,9 @@ async fn forward(interface: &str, source: SocketAddr, target: SocketAddr) -> Res
Ok(())
}
async fn unforward(interface: &str, source: SocketAddr, target: SocketAddr) -> Result<(), Error> {
if source.is_ipv6() {
return Ok(()); // TODO: socat? ip6tables?
}
async fn unforward(source: SocketAddrV4, target: SocketAddrV4) -> Result<(), Error> {
Command::new("/usr/lib/startos/scripts/forward-port")
.env("UNDO", "1")
.env("iiface", interface)
.env("oiface", START9_BRIDGE_IFACE)
.env("sip", source.ip().to_string())
.env("dip", target.ip().to_string())
.env("sport", source.port().to_string())

View File

@@ -3,10 +3,11 @@ use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::future::Future;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV6};
use std::sync::{Arc, Weak};
use std::task::Poll;
use std::task::{Poll, ready};
use std::time::Duration;
use clap::Parser;
use futures::future::Either;
use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
use helpers::NonDetachingJoinHandle;
use imbl::{OrdMap, OrdSet};
@@ -16,33 +17,34 @@ use itertools::Itertools;
use models::GatewayId;
use nix::net::if_::if_nametoindex;
use patch_db::json_ptr::JsonPointer;
use rpc_toolkit::{from_fn_async, Context, HandlerArgs, HandlerExt, ParentHandler};
use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::net::{TcpListener, TcpStream};
use tokio::net::TcpListener;
use tokio::process::Command;
use tokio::sync::oneshot;
use ts_rs::TS;
use visit_rs::{Visit, VisitFields};
use zbus::proxy::{PropertyChanged, PropertyStream, SignalStream};
use zbus::zvariant::{
DeserializeDict, Dict, OwnedObjectPath, OwnedValue, Type as ZType, Value as ZValue,
};
use zbus::{proxy, Connection};
use zbus::{Connection, proxy};
use crate::context::{CliContext, RpcContext};
use crate::db::model::public::{IpInfo, NetworkInterfaceInfo, NetworkInterfaceType};
use crate::db::model::Database;
use crate::db::model::public::{IpInfo, NetworkInterfaceInfo, NetworkInterfaceType};
use crate::net::forward::START9_BRIDGE_IFACE;
use crate::net::gateway::device::DeviceProxy;
use crate::net::utils::ipv6_is_link_local;
use crate::net::web_server::Accept;
use crate::net::web_server::{Accept, AcceptStream, Acceptor, MetadataVisitor};
use crate::prelude::*;
use crate::util::Invoke;
use crate::util::collections::OrdMapIterMut;
use crate::util::future::Until;
use crate::util::io::open_file;
use crate::util::serde::{display_serializable, HandlerExtSerde};
use crate::util::serde::{HandlerExtSerde, display_serializable};
use crate::util::sync::{SyncMutex, Watch};
use crate::util::Invoke;
pub fn gateway_api<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
@@ -244,6 +246,8 @@ mod active_connection {
default_service = "org.freedesktop.NetworkManager"
)]
trait ConnectionSettings {
fn delete(&self) -> Result<(), Error>;
fn get_settings(&self) -> Result<HashMap<String, HashMap<String, OwnedValue>>, Error>;
fn update2(
@@ -583,15 +587,12 @@ async fn watch_ip(
loop {
until
.run(async {
let external = active_connection_proxy.state_flags().await? & 0x80 != 0;
if external {
return Ok(());
}
let device_type = match device_proxy.device_type().await? {
1 => Some(NetworkInterfaceType::Ethernet),
2 => Some(NetworkInterfaceType::Wireless),
13 => Some(NetworkInterfaceType::Bridge),
29 => Some(NetworkInterfaceType::Wireguard),
32 => Some(NetworkInterfaceType::Loopback),
_ => None,
};
@@ -671,7 +672,14 @@ async fn watch_ip(
.into_iter()
.map(IpNet::try_from)
.try_collect()?;
let wan_ip = if !subnets.is_empty() {
let wan_ip = if !subnets.is_empty()
&& !matches!(
device_type,
Some(
NetworkInterfaceType::Bridge
| NetworkInterfaceType::Loopback
)
) {
match get_wan_ipv4(iface.as_str()).await {
Ok(a) => a,
Err(e) => {
@@ -711,6 +719,7 @@ async fn watch_ip(
)
});
ip_info.wan_ip = ip_info.wan_ip.or(prev_wan_ip);
let ip_info = Arc::new(ip_info);
m.insert(
iface.clone(),
NetworkInterfaceInfo {
@@ -785,13 +794,7 @@ impl NetworkInterfaceWatcher {
watch_activated: impl IntoIterator<Item = GatewayId>,
) -> Self {
let ip_info = Watch::new(OrdMap::new());
let activated = Watch::new(
watch_activated
.into_iter()
.chain([NetworkInterfaceInfo::lxc_bridge().0.clone()])
.map(|k| (k, false))
.collect(),
);
let activated = Watch::new(watch_activated.into_iter().map(|k| (k, false)).collect());
Self {
activated: activated.clone(),
ip_info: ip_info.clone(),
@@ -831,7 +834,7 @@ impl NetworkInterfaceWatcher {
self.ip_info.read()
}
pub fn bind(&self, port: u16) -> Result<NetworkInterfaceListener, Error> {
pub fn bind<B: Bind>(&self, bind: B, port: u16) -> Result<NetworkInterfaceListener<B>, Error> {
let arc = Arc::new(());
self.listeners.mutate(|l| {
if l.get(&port).filter(|w| w.strong_count() > 0).is_some() {
@@ -844,22 +847,20 @@ impl NetworkInterfaceWatcher {
Ok(())
})?;
let ip_info = self.ip_info.clone_unseen();
let activated = self.activated.clone_unseen();
Ok(NetworkInterfaceListener {
_arc: arc,
ip_info,
activated,
listeners: ListenerMap::new(port),
listeners: ListenerMap::new(bind, port),
})
}
pub fn upgrade_listener(
pub fn upgrade_listener<B: Bind>(
&self,
SelfContainedNetworkInterfaceListener {
mut listener,
..
}: SelfContainedNetworkInterfaceListener,
) -> Result<NetworkInterfaceListener, Error> {
}: SelfContainedNetworkInterfaceListener<B>,
) -> Result<NetworkInterfaceListener<B>, Error> {
let port = listener.listeners.port;
let arc = &listener._arc;
self.listeners.mutate(|l| {
@@ -1095,7 +1096,7 @@ impl NetworkInterfaceController {
.ip_info
.peek(|ifaces| ifaces.get(interface).map(|i| i.ip_info.is_some()))
else {
return Ok(());
return self.forget(interface).await;
};
if has_ip_info {
@@ -1115,7 +1116,21 @@ impl NetworkInterfaceController {
let device_proxy = DeviceProxy::new(&connection, device).await?;
device_proxy.delete().await?;
let ac = device_proxy.active_connection().await?;
if &*ac == "/" {
return Err(Error::new(
eyre!("Cannot delete device without active connection"),
ErrorKind::InvalidRequest,
));
}
let ac_proxy = active_connection::ActiveConnectionProxy::new(&connection, ac).await?;
let settings =
ConnectionSettingsProxy::new(&connection, ac_proxy.connection().await?).await?;
settings.delete().await?;
ip_info
.wait_for(|ifaces| ifaces.get(interface).map_or(true, |i| i.ip_info.is_none()))
@@ -1158,45 +1173,6 @@ impl NetworkInterfaceController {
}
}
struct ListenerMap {
prev_filter: DynInterfaceFilter,
port: u16,
listeners: BTreeMap<SocketAddr, (TcpListener, Option<Ipv4Addr>)>,
}
impl ListenerMap {
fn from_listener(listener: impl IntoIterator<Item = TcpListener>) -> Result<Self, Error> {
let mut port = 0;
let mut listeners = BTreeMap::<SocketAddr, (TcpListener, Option<Ipv4Addr>)>::new();
for listener in listener {
let mut local = listener.local_addr().with_kind(ErrorKind::Network)?;
if let SocketAddr::V6(l) = &mut local {
if ipv6_is_link_local(*l.ip()) && l.scope_id() == 0 {
continue; // TODO determine scope id
}
}
if port != 0 && port != local.port() {
return Err(Error::new(
eyre!("Provided listeners are bound to different ports"),
ErrorKind::InvalidRequest,
));
}
port = local.port();
listeners.insert(local, (listener, None));
}
if port == 0 {
return Err(Error::new(
eyre!("Listener array cannot be empty"),
ErrorKind::InvalidRequest,
));
}
Ok(Self {
prev_filter: false.into_dyn(),
port,
listeners,
})
}
}
pub trait InterfaceFilter: Any + Clone + std::fmt::Debug + Eq + Ord + Send + Sync {
fn filter(&self, id: &GatewayId, info: &NetworkInterfaceInfo) -> bool;
fn eq(&self, other: &dyn Any) -> bool {
@@ -1224,6 +1200,14 @@ impl InterfaceFilter for bool {
}
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct TypeFilter(pub NetworkInterfaceType);

impl InterfaceFilter for TypeFilter {
    /// Matches interfaces whose reported device type equals the wrapped type.
    fn filter(&self, _: &GatewayId, info: &NetworkInterfaceInfo) -> bool {
        // An interface with no ip_info has no known device type, so it can
        // never match.
        match info.ip_info.as_ref() {
            Some(ip) => ip.device_type == Some(self.0),
            None => false,
        }
    }
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IdFilter(pub GatewayId);
impl InterfaceFilter for IdFilter {
@@ -1355,10 +1339,17 @@ impl Ord for DynInterfaceFilter {
}
}
impl ListenerMap {
fn new(port: u16) -> Self {
struct ListenerMap<B: Bind> {
prev_filter: DynInterfaceFilter,
bind: B,
port: u16,
listeners: BTreeMap<SocketAddr, B::Accept>,
}
impl<B: Bind> ListenerMap<B> {
fn new(bind: B, port: u16) -> Self {
Self {
prev_filter: false.into_dyn(),
bind,
port,
listeners: BTreeMap::new(),
}
@@ -1368,14 +1359,11 @@ impl ListenerMap {
fn update(
&mut self,
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
lxc_bridge: bool,
filter: &impl InterfaceFilter,
) -> Result<(), Error> {
let mut keep = BTreeSet::<SocketAddr>::new();
for (_, info) in ip_info
.iter()
.chain([NetworkInterfaceInfo::loopback()])
.chain(Some(NetworkInterfaceInfo::lxc_bridge()).filter(|_| lxc_bridge))
.filter(|(id, info)| filter.filter(*id, *info))
{
if let Some(ip_info) = &info.ip_info {
@@ -1395,24 +1383,9 @@ impl ListenerMap {
ip => SocketAddr::new(ip, self.port),
};
keep.insert(addr);
if let Some((_, wan_ip)) = self.listeners.get_mut(&addr) {
*wan_ip = info.ip_info.as_ref().and_then(|i| i.wan_ip);
continue;
if !self.listeners.contains_key(&addr) {
self.listeners.insert(addr, self.bind.bind(addr)?);
}
self.listeners.insert(
addr,
(
TcpListener::from_std(
mio::net::TcpListener::bind(addr)
.with_ctx(|_| {
(ErrorKind::Network, lazy_format!("binding to {addr:?}"))
})?
.into(),
)
.with_kind(ErrorKind::Network)?,
info.ip_info.as_ref().and_then(|i| i.wan_ip),
),
);
}
}
}
@@ -1420,24 +1393,13 @@ impl ListenerMap {
self.prev_filter = filter.clone().into_dyn();
Ok(())
}
fn poll_accept(&self, cx: &mut std::task::Context<'_>) -> Poll<Result<Accepted, Error>> {
for (bind_addr, (listener, wan_ip)) in self.listeners.iter() {
if let Poll::Ready((stream, addr)) = listener.poll_accept(cx)? {
if let Err(e) = socket2::SockRef::from(&stream).set_tcp_keepalive(
&socket2::TcpKeepalive::new()
.with_time(Duration::from_secs(900))
.with_interval(Duration::from_secs(60))
.with_retries(5),
) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
return Poll::Ready(Ok(Accepted {
stream,
peer: addr,
wan_ip: *wan_ip,
bind: *bind_addr,
}));
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(SocketAddr, <B::Accept as Accept>::Metadata, AcceptStream), Error>> {
for (addr, listener) in self.listeners.iter_mut() {
if let Poll::Ready((metadata, stream)) = listener.poll_accept(cx)? {
return Poll::Ready(Ok((*addr, metadata, stream)));
}
}
Poll::Pending
@@ -1448,65 +1410,102 @@ pub fn lookup_info_by_addr(
ip_info: &OrdMap<GatewayId, NetworkInterfaceInfo>,
addr: SocketAddr,
) -> Option<(&GatewayId, &NetworkInterfaceInfo)> {
ip_info
.iter()
.chain([
NetworkInterfaceInfo::loopback(),
NetworkInterfaceInfo::lxc_bridge(),
])
.find(|(_, i)| {
i.ip_info
.as_ref()
.map_or(false, |i| i.subnets.iter().any(|i| i.addr() == addr.ip()))
})
ip_info.iter().find(|(_, i)| {
i.ip_info
.as_ref()
.map_or(false, |i| i.subnets.iter().any(|i| i.addr() == addr.ip()))
})
}
pub struct NetworkInterfaceListener {
/// Strategy for creating listening sockets: given a socket address, produce
/// a listener that can be polled for incoming connections.
pub trait Bind {
    /// The listener type produced by a successful `bind` (pollable via
    /// [`Accept`]).
    type Accept: Accept;
    /// Bind a new listener on `addr`.
    fn bind(&mut self, addr: SocketAddr) -> Result<Self::Accept, Error>;
}
#[derive(Clone, Copy, Default)]
pub struct BindTcp;
impl Bind for BindTcp {
type Accept = TcpListener;
fn bind(&mut self, addr: SocketAddr) -> Result<Self::Accept, Error> {
TcpListener::from_std(
mio::net::TcpListener::bind(addr)
.with_kind(ErrorKind::Network)?
.into(),
)
.with_kind(ErrorKind::Network)
}
}
/// Conversion from an accepted connection's gateway id + interface info into
/// a caller-chosen metadata type (the `M` parameter of
/// `NetworkInterfaceListener::poll_accept`).
pub trait FromGatewayInfo {
    /// Build `Self` from the gateway the connection arrived on.
    fn from_gateway_info(id: &GatewayId, info: &NetworkInterfaceInfo) -> Self;
}
/// Owned snapshot of a gateway: its id plus the network interface info that
/// was current when the snapshot was taken.
#[derive(Clone, Debug)]
pub struct GatewayInfo {
    pub id: GatewayId,
    pub info: NetworkInterfaceInfo,
}
// Visiting a `GatewayInfo` hands the whole value to the visitor in one call
// rather than recursing into its fields.
impl<V: MetadataVisitor> Visit<V> for GatewayInfo {
    fn visit(&self, visitor: &mut V) -> <V as visit_rs::Visitor>::Result {
        visitor.visit(self)
    }
}
impl FromGatewayInfo for GatewayInfo {
fn from_gateway_info(id: &GatewayId, info: &NetworkInterfaceInfo) -> Self {
Self {
id: id.clone(),
info: info.clone(),
}
}
}
pub struct NetworkInterfaceListener<B: Bind = BindTcp> {
pub ip_info: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
activated: Watch<BTreeMap<GatewayId, bool>>,
listeners: ListenerMap,
listeners: ListenerMap<B>,
_arc: Arc<()>,
}
impl NetworkInterfaceListener {
pub fn port(&self) -> u16 {
self.listeners.port
}
#[cfg_attr(feature = "unstable", inline(never))]
pub fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
filter: &impl InterfaceFilter,
) -> Poll<Result<Accepted, Error>> {
while self.ip_info.poll_changed(cx).is_ready()
|| self.activated.poll_changed(cx).is_ready()
|| !DynInterfaceFilterT::eq(&self.listeners.prev_filter, filter.as_any())
{
let lxc_bridge = self.activated.peek(|a| {
a.get(NetworkInterfaceInfo::lxc_bridge().0)
.copied()
.unwrap_or_default()
});
self.ip_info
.peek_and_mark_seen(|ip_info| self.listeners.update(ip_info, lxc_bridge, filter))?;
}
self.listeners.poll_accept(cx)
}
impl<B: Bind> NetworkInterfaceListener<B> {
pub(super) fn new(
mut ip_info: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
activated: Watch<BTreeMap<GatewayId, bool>>,
bind: B,
port: u16,
) -> Self {
ip_info.mark_unseen();
Self {
ip_info,
activated,
listeners: ListenerMap::new(port),
listeners: ListenerMap::new(bind, port),
_arc: Arc::new(()),
}
}
pub fn port(&self) -> u16 {
self.listeners.port
}
#[cfg_attr(feature = "unstable", inline(never))]
pub fn poll_accept<M: FromGatewayInfo>(
&mut self,
cx: &mut std::task::Context<'_>,
filter: &impl InterfaceFilter,
) -> Poll<Result<(M, <B::Accept as Accept>::Metadata, AcceptStream), Error>> {
while self.ip_info.poll_changed(cx).is_ready()
|| !DynInterfaceFilterT::eq(&self.listeners.prev_filter, filter.as_any())
{
self.ip_info
.peek_and_mark_seen(|ip_info| self.listeners.update(ip_info, filter))?;
}
let (addr, inner, stream) = ready!(self.listeners.poll_accept(cx)?);
Poll::Ready(Ok((
self.ip_info
.peek(|ip_info| {
lookup_info_by_addr(ip_info, addr)
.map(|(id, info)| M::from_gateway_info(id, info))
})
.or_not_found(lazy_format!("gateway for {addr}"))?,
inner,
stream,
)))
}
pub fn change_ip_info_source(
&mut self,
mut ip_info: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
@@ -1515,7 +1514,10 @@ impl NetworkInterfaceListener {
self.ip_info = ip_info;
}
pub async fn accept(&mut self, filter: &impl InterfaceFilter) -> Result<Accepted, Error> {
pub async fn accept<M: FromGatewayInfo>(
&mut self,
filter: &impl InterfaceFilter,
) -> Result<(M, <B::Accept as Accept>::Metadata, AcceptStream), Error> {
futures::future::poll_fn(|cx| self.poll_accept(cx, filter)).await
}
@@ -1531,37 +1533,84 @@ impl NetworkInterfaceListener {
}
}
pub struct Accepted {
pub stream: TcpStream,
pub peer: SocketAddr,
pub wan_ip: Option<Ipv4Addr>,
pub bind: SocketAddr,
#[derive(VisitFields)]
pub struct NetworkInterfaceListenerAcceptMetadata<B: Bind> {
pub inner: <B::Accept as Accept>::Metadata,
pub info: GatewayInfo,
}
pub struct SelfContainedNetworkInterfaceListener {
_watch_thread: NonDetachingJoinHandle<()>,
listener: NetworkInterfaceListener,
}
impl SelfContainedNetworkInterfaceListener {
pub fn bind(port: u16) -> Self {
let ip_info = Watch::new(OrdMap::new());
let activated = Watch::new(
[(NetworkInterfaceInfo::lxc_bridge().0.clone(), false)]
.into_iter()
.collect(),
);
let _watch_thread = tokio::spawn(watcher(ip_info.clone(), activated.clone())).into();
impl<B: Bind> Clone for NetworkInterfaceListenerAcceptMetadata<B>
where
<B::Accept as Accept>::Metadata: Clone,
{
fn clone(&self) -> Self {
Self {
_watch_thread,
listener: NetworkInterfaceListener::new(ip_info, activated, port),
inner: self.inner.clone(),
info: self.info.clone(),
}
}
}
impl Accept for SelfContainedNetworkInterfaceListener {
impl<B, V> Visit<V> for NetworkInterfaceListenerAcceptMetadata<B>
where
B: Bind,
<B::Accept as Accept>::Metadata: Visit<V> + Clone + Send + Sync + 'static,
V: MetadataVisitor,
{
fn visit(&self, visitor: &mut V) -> V::Result {
self.visit_fields(visitor).collect()
}
}
impl<B: Bind> Accept for NetworkInterfaceListener<B> {
type Metadata = NetworkInterfaceListenerAcceptMetadata<B>;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<super::web_server::Accepted, Error>> {
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
NetworkInterfaceListener::poll_accept(self, cx, &true).map(|res| {
res.map(|(info, inner, stream)| {
(
NetworkInterfaceListenerAcceptMetadata { inner, info },
stream,
)
})
})
}
}
/// A `NetworkInterfaceListener` bundled with its own background interface
/// watcher, so it works without access to a shared `NetworkInterfaceWatcher`
/// (it can later be upgraded to one via
/// `NetworkInterfaceWatcher::upgrade_listener`).
pub struct SelfContainedNetworkInterfaceListener<B: Bind = BindTcp> {
    // Keeps the private watcher task alive for the lifetime of the listener.
    _watch_thread: NonDetachingJoinHandle<()>,
    listener: NetworkInterfaceListener<B>,
}
impl<B: Bind> SelfContainedNetworkInterfaceListener<B> {
    /// Bind on `port` using `bind`, spawning a private `watcher` task to
    /// keep the ip-info map current.
    pub fn bind(bind: B, port: u16) -> Self {
        let ip_info = Watch::new(OrdMap::new());
        // The spawned watcher updates `ip_info` in the background; the handle
        // is retained (NonDetachingJoinHandle — presumably stops the task
        // when dropped; TODO confirm).
        let _watch_thread =
            tokio::spawn(watcher(ip_info.clone(), Watch::new(BTreeMap::new()))).into();
        Self {
            _watch_thread,
            listener: NetworkInterfaceListener::new(ip_info, bind, port),
        }
    }
}
impl<B: Bind> Accept for SelfContainedNetworkInterfaceListener<B> {
    type Metadata = <NetworkInterfaceListener<B> as Accept>::Metadata;
    fn poll_accept(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(Self::Metadata, AcceptStream), Error>> {
        // Delegate to the inner listener. Written fully qualified because
        // `NetworkInterfaceListener` also has an inherent `poll_accept` with
        // a different signature (it takes a filter argument).
        Accept::poll_accept(&mut self.listener, cx)
    }
}
/// A listener that starts out self-contained (running its own watcher task,
/// `Either::Left`) and can later be upgraded to share the main
/// `NetworkInterfaceWatcher` (`Either::Right`); `None` once consumed.
pub type UpgradableListener<B = BindTcp> =
    Option<Either<SelfContainedNetworkInterfaceListener<B>, NetworkInterfaceListener<B>>>;

impl<B> Acceptor<UpgradableListener<B>>
where
    B: Bind + Send + Sync + 'static,
    B::Accept: Send + Sync,
{
    /// Wrap a self-contained listener in an `Acceptor`, starting in the
    /// un-upgraded (`Either::Left`) state.
    pub fn bind_upgradable(listener: SelfContainedNetworkInterfaceListener<B>) -> Self {
        Self::new(Some(Either::Left(listener)))
    }
}

View File

@@ -4,17 +4,17 @@ use std::net::Ipv4Addr;
use clap::Parser;
use imbl_value::InternedString;
use models::GatewayId;
use rpc_toolkit::{from_fn_async, Context, Empty, HandlerArgs, HandlerExt, ParentHandler};
use rpc_toolkit::{Context, Empty, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use crate::context::{CliContext, RpcContext};
use crate::db::model::DatabaseModel;
use crate::net::acme::AcmeProvider;
use crate::net::host::{all_hosts, HostApiKind};
use crate::net::host::{HostApiKind, all_hosts};
use crate::net::tor::OnionAddress;
use crate::prelude::*;
use crate::util::serde::{display_serializable, HandlerExtSerde};
use crate::util::serde::{HandlerExtSerde, display_serializable};
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
@@ -105,8 +105,8 @@ fn handle_duplicates(db: &mut DatabaseModel) -> Result<(), Error> {
Ok(())
}
pub fn address_api<C: Context, Kind: HostApiKind>(
) -> ParentHandler<C, Kind::Params, Kind::InheritedParams> {
pub fn address_api<C: Context, Kind: HostApiKind>()
-> ParentHandler<C, Kind::Params, Kind::InheritedParams> {
ParentHandler::<C, Kind::Params, Kind::InheritedParams>::new()
.subcommand(
"domain",
@@ -357,15 +357,7 @@ pub async fn add_onion<Kind: HostApiKind>(
OnionParams { onion }: OnionParams,
inheritance: Kind::Inheritance,
) -> Result<(), Error> {
let onion = onion
.strip_suffix(".onion")
.ok_or_else(|| {
Error::new(
eyre!("onion hostname must end in .onion"),
ErrorKind::InvalidOnionAddress,
)
})?
.parse::<OnionAddress>()?;
let onion = onion.parse::<OnionAddress>()?;
ctx.db
.mutate(|db| {
db.as_private().as_key_store().as_onion().get_key(&onion)?;
@@ -388,15 +380,7 @@ pub async fn remove_onion<Kind: HostApiKind>(
OnionParams { onion }: OnionParams,
inheritance: Kind::Inheritance,
) -> Result<(), Error> {
let onion = onion
.strip_suffix(".onion")
.ok_or_else(|| {
Error::new(
eyre!("onion hostname must end in .onion"),
ErrorKind::InvalidOnionAddress,
)
})?
.parse::<OnionAddress>()?;
let onion = onion.parse::<OnionAddress>()?;
ctx.db
.mutate(|db| {
Kind::host_for(&inheritance, db)?

View File

@@ -1,11 +1,11 @@
use std::collections::{BTreeMap, BTreeSet};
use std::str::FromStr;
use clap::builder::ValueParserFactory;
use clap::Parser;
use clap::builder::ValueParserFactory;
use imbl::OrdSet;
use models::{FromStrParser, GatewayId, HostId};
use rpc_toolkit::{from_fn_async, Context, Empty, HandlerArgs, HandlerExt, ParentHandler};
use rpc_toolkit::{Context, Empty, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use ts_rs::TS;
@@ -16,7 +16,7 @@ use crate::net::gateway::InterfaceFilter;
use crate::net::host::HostApiKind;
use crate::net::vhost::AlpnInfo;
use crate::prelude::*;
use crate::util::serde::{display_serializable, HandlerExtSerde};
use crate::util::serde::{HandlerExtSerde, display_serializable};
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, TS)]
#[ts(export)]
@@ -170,8 +170,8 @@ pub struct AddSslOptions {
pub alpn: Option<AlpnInfo>,
}
pub fn binding<C: Context, Kind: HostApiKind>(
) -> ParentHandler<C, Kind::Params, Kind::InheritedParams> {
pub fn binding<C: Context, Kind: HostApiKind>()
-> ParentHandler<C, Kind::Params, Kind::InheritedParams> {
ParentHandler::<C, Kind::Params, Kind::InheritedParams>::new()
.subcommand(
"list",

View File

@@ -6,15 +6,15 @@ use clap::Parser;
use imbl_value::InternedString;
use itertools::Itertools;
use models::{HostId, PackageId};
use rpc_toolkit::{from_fn_async, Context, Empty, HandlerExt, OrEmpty, ParentHandler};
use rpc_toolkit::{Context, Empty, HandlerExt, OrEmpty, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use crate::context::RpcContext;
use crate::db::model::DatabaseModel;
use crate::net::forward::AvailablePorts;
use crate::net::host::address::{address_api, HostAddress, PublicDomainConfig};
use crate::net::host::binding::{binding, BindInfo, BindOptions};
use crate::net::host::address::{HostAddress, PublicDomainConfig, address_api};
use crate::net::host::binding::{BindInfo, BindOptions, binding};
use crate::net::service_interface::HostnameInfo;
use crate::net::tor::OnionAddress;
use crate::prelude::*;

View File

@@ -8,6 +8,7 @@ use crate::prelude::*;
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[model = "Model<Self>"]
#[serde(rename_all = "camelCase")]
pub struct KeyStore {
pub onion: OnionStore,
pub local_certs: CertStore,

View File

@@ -12,6 +12,7 @@ pub mod service_interface;
pub mod socks;
pub mod ssl;
pub mod static_server;
pub mod tls;
pub mod tor;
pub mod tunnel;
pub mod utils;

View File

@@ -1,46 +1,48 @@
use std::collections::{BTreeMap, BTreeSet};
use std::net::{Ipv4Addr, SocketAddr};
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::sync::{Arc, Weak};
use color_eyre::eyre::eyre;
use imbl::{vector, OrdMap};
use imbl::{OrdMap, vector};
use imbl_value::InternedString;
use ipnet::IpNet;
use models::{HostId, OptionExt, PackageId};
use models::{GatewayId, HostId, OptionExt, PackageId};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio_rustls::rustls::ClientConfig as TlsClientConfig;
use tracing::instrument;
use crate::db::model::public::{NetworkInterfaceInfo, NetworkInterfaceType};
use crate::HOST_IP;
use crate::db::model::Database;
use crate::db::model::public::NetworkInterfaceType;
use crate::error::ErrorCollection;
use crate::hostname::Hostname;
use crate::net::dns::DnsController;
use crate::net::forward::PortForwardController;
use crate::net::forward::{InterfacePortForwardController, START9_BRIDGE_IFACE};
use crate::net::gateway::{
AndFilter, DynInterfaceFilter, IdFilter, InterfaceFilter, NetworkInterfaceController, OrFilter,
PublicFilter, SecureFilter,
PublicFilter, SecureFilter, TypeFilter,
};
use crate::net::host::address::HostAddress;
use crate::net::host::binding::{AddSslOptions, BindId, BindOptions};
use crate::net::host::{host_for, Host, Hosts};
use crate::net::host::{Host, Hosts, host_for};
use crate::net::service_interface::{GatewayInfo, HostnameInfo, IpHostname, OnionHostname};
use crate::net::socks::SocksController;
use crate::net::tor::{OnionAddress, TorController, TorSecretKey};
use crate::net::utils::ipv6_is_local;
use crate::net::vhost::{AlpnInfo, TargetInfo, VHostController};
use crate::net::vhost::{AlpnInfo, DynVHostTarget, ProxyTarget, VHostController};
use crate::prelude::*;
use crate::service::effects::callbacks::ServiceCallbacks;
use crate::util::serde::MaybeUtf8String;
use crate::HOST_IP;
pub struct NetController {
pub(crate) db: TypedPatchDb<Database>,
pub(super) tor: TorController,
pub(super) vhost: VHostController,
pub(super) tls_client_config: Arc<TlsClientConfig>,
pub(crate) net_iface: Arc<NetworkInterfaceController>,
pub(super) dns: DnsController,
pub(super) forward: PortForwardController,
pub(super) forward: InterfacePortForwardController,
pub(super) socks: SocksController,
pub(super) server_hostnames: Vec<Option<InternedString>>,
pub(crate) callbacks: Arc<ServiceCallbacks>,
@@ -55,12 +57,26 @@ impl NetController {
let net_iface = Arc::new(NetworkInterfaceController::new(db.clone()));
let tor = TorController::new()?;
let socks = SocksController::new(socks_listen, tor.clone())?;
let crypto_provider = Arc::new(tokio_rustls::rustls::crypto::ring::default_provider());
let tls_client_config = Arc::new(crate::net::tls::client_config(
crypto_provider.clone(),
[&*db
.peek()
.await
.as_private()
.as_key_store()
.as_local_certs()
.as_root_cert()
.de()?
.0],
)?);
Ok(Self {
db: db.clone(),
tor,
vhost: VHostController::new(db.clone(), net_iface.clone()),
vhost: VHostController::new(db.clone(), net_iface.clone(), crypto_provider),
tls_client_config,
dns: DnsController::init(db, &net_iface.watcher).await?,
forward: PortForwardController::new(net_iface.watcher.subscribe()),
forward: InterfacePortForwardController::new(net_iface.watcher.subscribe()),
net_iface,
socks,
server_hostnames: vec![
@@ -133,8 +149,8 @@ impl NetController {
#[derive(Default, Debug)]
struct HostBinds {
forwards: BTreeMap<u16, (SocketAddr, DynInterfaceFilter, Arc<()>)>,
vhosts: BTreeMap<(Option<InternedString>, u16), (TargetInfo, Arc<()>)>,
forwards: BTreeMap<u16, (SocketAddrV4, DynInterfaceFilter, Arc<()>)>,
vhosts: BTreeMap<(Option<InternedString>, u16), (ProxyTarget, Arc<()>)>,
private_dns: BTreeMap<InternedString, Arc<()>>,
tor: BTreeMap<OnionAddress, (OrdMap<u16, SocketAddr>, Vec<Arc<()>>)>,
}
@@ -225,8 +241,8 @@ impl NetServiceData {
}
async fn update(&mut self, ctrl: &NetController, id: HostId, host: Host) -> Result<(), Error> {
let mut forwards: BTreeMap<u16, (SocketAddr, DynInterfaceFilter)> = BTreeMap::new();
let mut vhosts: BTreeMap<(Option<InternedString>, u16), TargetInfo> = BTreeMap::new();
let mut forwards: BTreeMap<u16, (SocketAddrV4, DynInterfaceFilter)> = BTreeMap::new();
let mut vhosts: BTreeMap<(Option<InternedString>, u16), ProxyTarget> = BTreeMap::new();
let mut private_dns: BTreeSet<InternedString> = BTreeSet::new();
let mut tor: BTreeMap<OnionAddress, (TorSecretKey, OrdMap<u16, SocketAddr>)> =
BTreeMap::new();
@@ -263,11 +279,13 @@ impl NetServiceData {
for hostname in ctrl.server_hostnames.iter().cloned() {
vhosts.insert(
(hostname, external),
TargetInfo {
ProxyTarget {
filter: bind.net.clone().into_dyn(),
acme: None,
addr,
connect_ssl: connect_ssl.clone(),
connect_ssl: connect_ssl
.clone()
.map(|_| ctrl.tls_client_config.clone()),
},
);
}
@@ -278,19 +296,19 @@ impl NetServiceData {
if hostnames.insert(hostname.clone()) {
vhosts.insert(
(Some(hostname), external),
TargetInfo {
ProxyTarget {
filter: OrFilter(
IdFilter(
NetworkInterfaceInfo::loopback().0.clone(),
),
IdFilter(
NetworkInterfaceInfo::lxc_bridge().0.clone(),
),
TypeFilter(NetworkInterfaceType::Loopback),
IdFilter(GatewayId::from(InternedString::from(
START9_BRIDGE_IFACE,
))),
)
.into_dyn(),
acme: None,
addr,
connect_ssl: connect_ssl.clone(),
connect_ssl: connect_ssl
.clone()
.map(|_| ctrl.tls_client_config.clone()),
},
); // TODO: wrap onion ssl stream directly in tor ctrl
}
@@ -306,7 +324,7 @@ impl NetServiceData {
if let Some(public) = &public {
vhosts.insert(
(address.clone(), 5443),
TargetInfo {
ProxyTarget {
filter: AndFilter(
bind.net.clone(),
AndFilter(
@@ -317,12 +335,14 @@ impl NetServiceData {
.into_dyn(),
acme: public.acme.clone(),
addr,
connect_ssl: connect_ssl.clone(),
connect_ssl: connect_ssl
.clone()
.map(|_| ctrl.tls_client_config.clone()),
},
);
vhosts.insert(
(address.clone(), 443),
TargetInfo {
ProxyTarget {
filter: AndFilter(
bind.net.clone(),
if private {
@@ -342,13 +362,15 @@ impl NetServiceData {
.into_dyn(),
acme: public.acme.clone(),
addr,
connect_ssl: connect_ssl.clone(),
connect_ssl: connect_ssl
.clone()
.map(|_| ctrl.tls_client_config.clone()),
},
);
} else {
vhosts.insert(
(address.clone(), 443),
TargetInfo {
ProxyTarget {
filter: AndFilter(
bind.net.clone(),
PublicFilter { public: false },
@@ -356,7 +378,9 @@ impl NetServiceData {
.into_dyn(),
acme: None,
addr,
connect_ssl: connect_ssl.clone(),
connect_ssl: connect_ssl
.clone()
.map(|_| ctrl.tls_client_config.clone()),
},
);
}
@@ -364,7 +388,7 @@ impl NetServiceData {
if let Some(public) = public {
vhosts.insert(
(address.clone(), external),
TargetInfo {
ProxyTarget {
filter: AndFilter(
bind.net.clone(),
if private {
@@ -381,13 +405,15 @@ impl NetServiceData {
.into_dyn(),
acme: public.acme.clone(),
addr,
connect_ssl: connect_ssl.clone(),
connect_ssl: connect_ssl
.clone()
.map(|_| ctrl.tls_client_config.clone()),
},
);
} else {
vhosts.insert(
(address.clone(), external),
TargetInfo {
ProxyTarget {
filter: AndFilter(
bind.net.clone(),
PublicFilter { public: false },
@@ -395,7 +421,9 @@ impl NetServiceData {
.into_dyn(),
acme: None,
addr,
connect_ssl: connect_ssl.clone(),
connect_ssl: connect_ssl
.clone()
.map(|_| ctrl.tls_client_config.clone()),
},
);
}
@@ -414,7 +442,7 @@ impl NetServiceData {
forwards.insert(
external,
(
(self.ip, *port).into(),
SocketAddrV4::new(self.ip, *port),
AndFilter(
SecureFilter {
secure: bind.options.secure.is_some(),
@@ -429,6 +457,11 @@ impl NetServiceData {
hostname_info.remove(port).unwrap_or_default();
for (gateway_id, info) in net_ifaces
.iter()
.filter(|(_, info)| {
info.ip_info.as_ref().map_or(false, |i| {
!matches!(i.device_type, Some(NetworkInterfaceType::Bridge))
})
})
.filter(|(id, info)| bind.net.filter(id, info))
{
let gateway = GatewayInfo {
@@ -653,7 +686,10 @@ impl NetServiceData {
if let Some(prev) = prev {
prev
} else {
(target.clone(), ctrl.vhost.add(key.0, key.1, target)?)
(
target.clone(),
ctrl.vhost.add(key.0, key.1, DynVHostTarget::new(target))?,
)
},
);
} else {
@@ -688,7 +724,7 @@ impl NetServiceData {
.collect::<BTreeSet<_>>();
for onion in all {
let mut prev = binds.tor.remove(&onion);
if let Some((key, tor_binds)) = tor.remove(&onion) {
if let Some((key, tor_binds)) = tor.remove(&onion).filter(|(_, b)| !b.is_empty()) {
prev = prev.filter(|(b, _)| b == &tor_binds);
binds.tor.insert(
onion,

View File

@@ -8,10 +8,10 @@ use socks5_impl::server::auth::NoAuth;
use socks5_impl::server::{AuthAdaptor, ClientConnection, Server};
use tokio::net::{TcpListener, TcpStream};
use crate::HOST_IP;
use crate::net::tor::TorController;
use crate::prelude::*;
use crate::util::actor::background::BackgroundJobQueue;
use crate::HOST_IP;
pub const DEFAULT_SOCKS_LISTEN: SocketAddr = SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::new(HOST_IP[0], HOST_IP[1], HOST_IP[2], HOST_IP[3]),

View File

@@ -2,6 +2,7 @@ use std::cmp::{Ordering, min};
use std::collections::{BTreeMap, BTreeSet};
use std::net::IpAddr;
use std::path::Path;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use futures::FutureExt;
@@ -10,22 +11,43 @@ use libc::time_t;
use openssl::asn1::{Asn1Integer, Asn1Time, Asn1TimeRef};
use openssl::bn::{BigNum, MsbOption};
use openssl::ec::{EcGroup, EcKey};
use openssl::error::ErrorStack;
use openssl::hash::MessageDigest;
use openssl::nid::Nid;
use openssl::pkey::{PKey, Private};
use openssl::x509::{X509, X509Builder, X509Extension, X509NameBuilder};
use openssl::x509::extension::{
AuthorityKeyIdentifier, BasicConstraints, KeyUsage, SubjectAlternativeName,
SubjectKeyIdentifier,
};
use openssl::x509::{X509, X509Builder, X509NameBuilder};
use openssl::*;
use patch_db::HasModel;
use serde::{Deserialize, Serialize};
use tokio_rustls::rustls::ServerConfig;
use tokio_rustls::rustls::crypto::CryptoProvider;
use tokio_rustls::rustls::pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio_rustls::rustls::server::ClientHello;
use tracing::instrument;
use visit_rs::Visit;
use crate::SOURCE_DATE;
use crate::account::AccountInfo;
use crate::db::model::Database;
use crate::db::{DbAccess, DbAccessMut};
use crate::hostname::Hostname;
use crate::init::check_time_is_synchronized;
use crate::net::gateway::GatewayInfo;
use crate::net::tls::TlsHandler;
use crate::net::web_server::{Accept, ExtractVisitor, TcpMetadata, extract};
use crate::prelude::*;
use crate::util::serde::Pem;
/// Generates a fresh NIST P-256 (prime256v1) EC keypair wrapped as a `PKey`.
pub fn gen_nistp256() -> Result<PKey<Private>, ErrorStack> {
    let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?;
    let ec_key = EcKey::generate(&group)?;
    PKey::from_ec_key(ec_key)
}
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[model = "Model<Self>"]
#[serde(rename_all = "camelCase")]
@@ -96,9 +118,7 @@ impl Model<CertStore> {
} else {
PKeyPair {
ed25519: PKey::generate_ed25519()?,
nistp256: PKey::from_ec_key(EcKey::generate(&*EcGroup::from_curve_name(
Nid::X9_62_PRIME256V1,
)?)?)?,
nistp256: gen_nistp256()?,
}
};
let int_key = self.as_int_key().de()?.0;
@@ -125,6 +145,16 @@ impl Model<CertStore> {
})
}
}
impl DbAccess<CertStore> for Database {
    /// Shared (read-only) accessor for the locally issued certificate store,
    /// located at `private.keyStore.localCerts` in the database tree.
    fn access<'a>(db: &'a Model<Self>) -> &'a Model<CertStore> {
        let key_store = db.as_private().as_key_store();
        key_store.as_local_certs()
    }
}
impl DbAccessMut<CertStore> for Database {
    /// Mutable accessor for the locally issued certificate store, mirroring
    /// the read-only `DbAccess` path.
    fn access_mut<'a>(db: &'a mut Model<Self>) -> &'a mut Model<CertStore> {
        let key_store = db.as_private_mut().as_key_store_mut();
        key_store.as_local_certs_mut()
    }
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct CertData {
@@ -300,22 +330,16 @@ pub fn make_root_cert(
let cfg = conf::Conf::new(conf::ConfMethod::default())?;
let ctx = builder.x509v3_context(None, Some(&cfg));
// subjectKeyIdentifier = hash
let subject_key_identifier =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_KEY_IDENTIFIER, "hash")?;
let subject_key_identifier = SubjectKeyIdentifier::new().build(&ctx)?;
// basicConstraints = critical, CA:true, pathlen:0
let basic_constraints = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::BASIC_CONSTRAINTS,
"critical,CA:true",
)?;
let basic_constraints = BasicConstraints::new().critical().ca().build()?;
// keyUsage = critical, digitalSignature, cRLSign, keyCertSign
let key_usage = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::KEY_USAGE,
"critical,digitalSignature,cRLSign,keyCertSign",
)?;
let key_usage = KeyUsage::new()
.critical()
.digital_signature()
.crl_sign()
.key_cert_sign()
.build()?;
builder.append_extension(subject_key_identifier)?;
builder.append_extension(basic_constraints)?;
builder.append_extension(key_usage)?;
@@ -350,30 +374,23 @@ pub fn make_int_cert(
let cfg = conf::Conf::new(conf::ConfMethod::default())?;
let ctx = builder.x509v3_context(Some(&signer.1), Some(&cfg));
// subjectKeyIdentifier = hash
let subject_key_identifier =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_KEY_IDENTIFIER, "hash")?;
let subject_key_identifier = SubjectKeyIdentifier::new().build(&ctx)?;
// authorityKeyIdentifier = keyid:always,issuer
let authority_key_identifier = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::AUTHORITY_KEY_IDENTIFIER,
"keyid:always,issuer",
)?;
let authority_key_identifier = AuthorityKeyIdentifier::new()
.keyid(false)
.issuer(true)
.build(&ctx)?;
// basicConstraints = critical, CA:true, pathlen:0
let basic_constraints = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::BASIC_CONSTRAINTS,
"critical,CA:true,pathlen:0",
)?;
let basic_constraints = BasicConstraints::new().critical().ca().pathlen(0).build()?;
// keyUsage = critical, digitalSignature, cRLSign, keyCertSign
let key_usage = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::KEY_USAGE,
"critical,digitalSignature,cRLSign,keyCertSign",
)?;
let key_usage = KeyUsage::new()
.critical()
.digital_signature()
.crl_sign()
.key_cert_sign()
.build()?;
builder.append_extension(subject_key_identifier)?;
builder.append_extension(authority_key_identifier)?;
builder.append_extension(basic_constraints)?;
@@ -423,6 +440,16 @@ impl SANInfo {
}
Self { dns, ips }
}
/// Builds a `SubjectAlternativeName` X.509 extension builder covering every
/// DNS hostname and IP address recorded in this `SANInfo`.
///
/// The caller finalizes it with `.build(&ctx)` against an X509v3 context.
pub fn x509_extension(&self) -> SubjectAlternativeName {
    let mut san = SubjectAlternativeName::new();
    for h in &self.dns {
        // `as_str` already yields `&str`; the previous `&h.as_str()` created
        // a needless `&&str` that only worked via deref coercion.
        san.dns(h.as_str());
    }
    for ip in &self.ips {
        san.ip(&ip.to_string());
    }
    san
}
}
impl std::fmt::Display for SANInfo {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -485,28 +512,20 @@ pub fn make_leaf_cert(
// Extensions
let cfg = conf::Conf::new(conf::ConfMethod::default())?;
let ctx = builder.x509v3_context(Some(&signer.1), Some(&cfg));
// subjectKeyIdentifier = hash
let subject_key_identifier =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_KEY_IDENTIFIER, "hash")?;
// authorityKeyIdentifier = keyid:always,issuer
let authority_key_identifier = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::AUTHORITY_KEY_IDENTIFIER,
"keyid,issuer:always",
)?;
let basic_constraints =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::BASIC_CONSTRAINTS, "CA:FALSE")?;
let key_usage = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::KEY_USAGE,
"critical,digitalSignature,keyEncipherment",
)?;
let san_string = applicant.1.to_string();
let subject_alt_name =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_ALT_NAME, &san_string)?;
let subject_key_identifier = SubjectKeyIdentifier::new().build(&ctx)?;
let authority_key_identifier = AuthorityKeyIdentifier::new()
.keyid(false)
.issuer(true)
.build(&ctx)?;
let subject_alt_name = applicant.1.x509_extension().build(&ctx)?;
let basic_constraints = BasicConstraints::new().build()?;
let key_usage = KeyUsage::new()
.critical()
.digital_signature()
.key_encipherment()
.build()?;
builder.append_extension(subject_key_identifier)?;
builder.append_extension(authority_key_identifier)?;
builder.append_extension(subject_alt_name)?;
@@ -518,3 +537,161 @@ pub fn make_leaf_cert(
let cert = builder.build();
Ok(cert)
}
/// Builds a self-signed TLS leaf certificate for the given keypair and
/// subject-alternative-name set.
///
/// `applicant` is `(private key, SAN info)`. The certificate is issued and
/// signed by that same key (issuer == subject), so it chains to nothing.
/// CN falls back to "localhost" when the SAN set contains no DNS name.
#[instrument(skip_all)]
pub fn make_self_signed(applicant: (&PKey<Private>, &SANInfo)) -> Result<X509, Error> {
    let mut builder = X509Builder::new()?;
    builder.set_version(CERTIFICATE_VERSION)?;
    // Backdate validity by one day so minor clock skew doesn't render the
    // certificate "not yet valid" on other machines.
    let embargo = Asn1Time::from_unix(unix_time(SystemTime::now()) - 86400)?;
    builder.set_not_before(&embargo)?;
    // Google Apple and Mozilla reject certificate horizons longer than 398 days
    // https://techbeacon.com/security/google-apple-mozilla-enforce-1-year-max-security-certifications
    let expiration = Asn1Time::days_from_now(397)?;
    builder.set_not_after(&expiration)?;
    builder.set_serial_number(&*rand_serial()?)?;
    let mut subject_name_builder = X509NameBuilder::new()?;
    subject_name_builder.append_entry_by_text(
        "CN",
        applicant
            .1
            .dns
            .first()
            .map(MaybeWildcard::as_str)
            .unwrap_or("localhost"),
    )?;
    subject_name_builder.append_entry_by_text("O", "Start9")?;
    subject_name_builder.append_entry_by_text("OU", "StartOS")?;
    let subject_name = subject_name_builder.build();
    builder.set_subject_name(&subject_name)?;
    // Self-signed: the issuer name is identical to the subject name.
    builder.set_issuer_name(&subject_name)?;
    builder.set_pubkey(&applicant.0)?;
    // Extensions
    let cfg = conf::Conf::new(conf::ConfMethod::default())?;
    let ctx = builder.x509v3_context(None, Some(&cfg));
    let subject_key_identifier = SubjectKeyIdentifier::new().build(&ctx)?;
    let authority_key_identifier = AuthorityKeyIdentifier::new()
        .keyid(false)
        .issuer(true)
        .build(&ctx)?;
    let subject_alt_name = applicant.1.x509_extension().build(&ctx)?;
    // Leaf certificate: no CA flag, signing + key-encipherment usage only.
    let basic_constraints = BasicConstraints::new().build()?;
    let key_usage = KeyUsage::new()
        .critical()
        .digital_signature()
        .key_encipherment()
        .build()?;
    builder.append_extension(subject_key_identifier)?;
    builder.append_extension(authority_key_identifier)?;
    builder.append_extension(subject_alt_name)?;
    builder.append_extension(basic_constraints)?;
    builder.append_extension(key_usage)?;
    builder.sign(&applicant.0, MessageDigest::sha256())?;
    let cert = builder.build();
    Ok(cert)
}
/// TLS handler that serves certificates chained to the device's local root
/// CA, obtaining (and persisting) a leaf certificate for the hostnames each
/// handshake requests.
pub struct RootCaTlsHandler<M: HasModel> {
    /// Database handle used to look up / mint certificates in `M`'s cert store.
    pub db: TypedPatchDb<M>,
    /// rustls crypto provider used when assembling per-connection `ServerConfig`s.
    pub crypto_provider: Arc<CryptoProvider>,
}
// Manual impl: `#[derive(Clone)]` would add an unnecessary `M: Clone` bound,
// while both fields are handles that are cloneable regardless of `M`.
impl<M: HasModel> Clone for RootCaTlsHandler<M> {
    fn clone(&self) -> Self {
        Self {
            db: self.db.clone(),
            crypto_provider: self.crypto_provider.clone(),
        }
    }
}
impl<'a, A, M> TlsHandler<'a, A> for RootCaTlsHandler<M>
where
    A: Accept + 'a,
    <A as Accept>::Metadata: Visit<ExtractVisitor<TcpMetadata>>
        + Visit<ExtractVisitor<GatewayInfo>>
        + Clone
        + Send
        + Sync
        + 'static,
    M: HasModel<Model = Model<M>> + DbAccessMut<CertStore> + Send + Sync,
{
    /// Resolves a `ServerConfig` for an incoming handshake using a leaf cert
    /// signed by the local root CA.
    ///
    /// SAN candidates are collected from (1) the client's SNI, (2) the local
    /// address of the TCP connection, and (3) the gateway's WAN IP when
    /// known. Returns `None` (after logging) if certificate lookup/issuance
    /// or config construction fails, which declines the connection.
    async fn get_config(
        &mut self,
        hello: &ClientHello<'_>,
        metadata: &<A as Accept>::Metadata,
    ) -> Option<ServerConfig> {
        let hostnames: BTreeSet<InternedString> = hello
            .server_name()
            .map(InternedString::from)
            .into_iter()
            .chain(
                extract::<TcpMetadata, _>(metadata)
                    .map(|m| m.local_addr.ip())
                    .as_ref()
                    .map(InternedString::from_display),
            )
            .chain(
                extract::<GatewayInfo, _>(metadata)
                    .and_then(|i| i.info.ip_info)
                    .and_then(|i| i.wan_ip)
                    .as_ref()
                    .map(InternedString::from_display),
            )
            .collect();
        // `cert_for` mints and persists a new leaf if no valid one exists.
        let cert = self
            .db
            .mutate(|db| M::access_mut(db).cert_for(&hostnames))
            .await
            .result
            .log_err()?;
        let cfg = ServerConfig::builder_with_provider(self.crypto_provider.clone())
            .with_safe_default_protocol_versions()
            .log_err()?
            .with_no_client_auth();
        // Prefer the ed25519 chain when the client advertises support for it;
        // otherwise fall back to the NIST P-256 chain.
        if hello
            .signature_schemes()
            .contains(&tokio_rustls::rustls::SignatureScheme::ED25519)
        {
            cfg.with_single_cert(
                cert.fullchain_ed25519()
                    .into_iter()
                    .map(|c| {
                        Ok(tokio_rustls::rustls::pki_types::CertificateDer::from(
                            c.to_der()?,
                        ))
                    })
                    .collect::<Result<_, Error>>()
                    .log_err()?,
                PrivateKeyDer::from(PrivatePkcs8KeyDer::from(
                    cert.leaf.keys.ed25519.private_key_to_pkcs8().log_err()?,
                )),
            )
        } else {
            cfg.with_single_cert(
                cert.fullchain_nistp256()
                    .into_iter()
                    .map(|c| {
                        Ok(tokio_rustls::rustls::pki_types::CertificateDer::from(
                            c.to_der()?,
                        ))
                    })
                    .collect::<Result<_, Error>>()
                    .log_err()?,
                PrivateKeyDer::from(PrivatePkcs8KeyDer::from(
                    cert.leaf.keys.nistp256.private_key_to_pkcs8().log_err()?,
                )),
            )
        }
        .log_err()
    }
}

View File

@@ -9,7 +9,7 @@ use async_compression::tokio::bufread::GzipEncoder;
use axum::Router;
use axum::body::Body;
use axum::extract::{self as x, Request};
use axum::response::{Redirect, Response};
use axum::response::{IntoResponse, Redirect, Response};
use axum::routing::{any, get};
use base64::display::Base64Display;
use digest::Digest;
@@ -26,16 +26,19 @@ use models::PackageId;
use new_mime_guess::MimeGuess;
use openssl::hash::MessageDigest;
use openssl::x509::X509;
use rpc_toolkit::{Context, HttpServer, Server};
use rpc_toolkit::{Context, HttpServer, ParentHandler, Server};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeekExt, BufReader};
use tokio_util::io::ReaderStream;
use url::Url;
use crate::context::{DiagnosticContext, InitContext, InstallContext, RpcContext, SetupContext};
use crate::hostname::Hostname;
use crate::main_api;
use crate::middleware::auth::{Auth, HasValidSession};
use crate::middleware::cors::Cors;
use crate::middleware::db::SyncDb;
use crate::net::gateway::GatewayInfo;
use crate::net::tls::TlsHandshakeInfo;
use crate::prelude::*;
use crate::rpc_continuations::{Guid, RpcContinuations};
use crate::s9pk::S9pk;
@@ -46,7 +49,6 @@ use crate::sign::commitment::merkle_archive::MerkleArchiveCommitment;
use crate::util::io::open_file;
use crate::util::net::SyncBody;
use crate::util::serde::BASE64;
use crate::{diagnostic_api, init_api, install_api, main_api, setup_api};
const NOT_FOUND: &[u8] = b"Not Found";
const METHOD_NOT_ALLOWED: &[u8] = b"Method Not Allowed";
@@ -55,26 +57,151 @@ const INTERNAL_SERVER_ERROR: &[u8] = b"Internal Server Error";
const PROXY_STRIP_HEADERS: &[&str] = &["cookie", "host", "origin", "referer", "user-agent"];
#[cfg(all(feature = "startd", not(feature = "test")))]
const EMBEDDED_UIS: Dir<'_> =
include_dir::include_dir!("$CARGO_MANIFEST_DIR/../../web/dist/static");
#[cfg(not(all(feature = "startd", not(feature = "test"))))]
const EMBEDDED_UIS: Dir<'_> = Dir::new("", &[]);
/// Zero-entry embedded directory, substituted for a real UI bundle when the
/// corresponding feature is compiled out (see `else_empty_dir!`).
pub const EMPTY_DIR: Dir<'_> = Dir::new("", &[]);
#[derive(Clone)]
pub enum UiMode {
Setup,
Install,
Main,
/// Evaluates to `$dir` when `$cfg` is enabled and this is not a test build;
/// otherwise substitutes the shared [`EMPTY_DIR`], so the embedded-UI
/// constants always exist regardless of feature flags.
#[macro_export]
macro_rules! else_empty_dir {
    ($cfg:meta => $dir:expr) => {{
        #[cfg(all($cfg, not(feature = "test")))]
        {
            $dir
        }
        #[cfg(not(all($cfg, not(feature = "test"))))]
        {
            crate::net::static_server::EMPTY_DIR
        }
    }};
}
impl UiMode {
fn path(&self, path: &str) -> PathBuf {
match self {
Self::Setup => Path::new("setup-wizard").join(path),
Self::Install => Path::new("install-wizard").join(path),
Self::Main => Path::new("ui").join(path),
// Entire embedded web bundle (all wizards/UIs). NOTE(review): each
// `UiContext` impl below embeds its own subdirectory via `UI_DIR`; confirm
// this root constant is still referenced anywhere before keeping it.
const EMBEDDED_UI_ROOT: Dir<'_> = else_empty_dir!(
    feature = "startd" =>
    include_dir::include_dir!("$CARGO_MANIFEST_DIR/../../web/dist/static")
);
/// Per-context wiring for serving a web UI: which embedded asset directory to
/// serve, which RPC API to expose, which middleware stack to apply, and an
/// optional hook for extra routes.
pub trait UiContext: Context + AsRef<RpcContinuations> + Clone + Sized {
    /// Embedded static assets for this context's single-page app.
    const UI_DIR: &'static Dir<'static>;
    /// Root RPC handler exposed at the JSON-RPC endpoint.
    fn api() -> ParentHandler<Self>;
    /// Middleware stack wrapped around the RPC server.
    fn middleware(server: Server<Self>) -> HttpServer<Self>;
    /// Hook for context-specific routes; the default adds none.
    fn extend_router(self, router: Router) -> Router {
        router
    }
}
impl UiContext for RpcContext {
    // The main (post-setup) UI bundle.
    const UI_DIR: &'static Dir<'static> = &else_empty_dir!(
        feature = "startd" =>
        include_dir::include_dir!("$CARGO_MANIFEST_DIR/../../web/dist/static/ui")
    );
    fn api() -> ParentHandler<Self> {
        main_api()
    }
    // Full middleware stack: CORS, session auth, and db-sync headers.
    fn middleware(server: Server<Self>) -> HttpServer<Self> {
        server
            .middleware(Cors::new())
            .middleware(Auth::new())
            .middleware(SyncDb::new())
    }
    /// Adds the authenticated `/proxy` route, s9pk sideload routes, the
    /// root-CA download endpoint, and a layer that 307-redirects plain-HTTP
    /// requests arriving via a public gateway to HTTPS.
    fn extend_router(self, router: Router) -> Router {
        // Redirect only when the request came through a public gateway AND no
        // TLS handshake info is attached (i.e. it arrived as cleartext HTTP).
        async fn https_redirect_if_public_http(
            req: Request,
            next: axum::middleware::Next,
        ) -> Response {
            if req
                .extensions()
                .get::<GatewayInfo>()
                .map_or(false, |p| p.info.public())
                && req.extensions().get::<TlsHandshakeInfo>().is_none()
            {
                Redirect::temporary(&format!(
                    "https://{}{}",
                    req.headers()
                        .get(HOST)
                        .and_then(|s| s.to_str().ok())
                        .unwrap_or("localhost"),
                    req.uri()
                ))
                .into_response()
            } else {
                next.run(req).await
            }
        }
        router
            .route("/proxy/{url}", {
                let ctx = self.clone();
                any(move |x::Path(url): x::Path<String>, request: Request| {
                    let ctx = ctx.clone();
                    async move {
                        proxy_request(ctx, request, url)
                            .await
                            .unwrap_or_else(server_error)
                    }
                })
            })
            .nest("/s9pk", s9pk_router(self.clone()))
            .route(
                "/static/local-root-ca.crt",
                get(move || {
                    let ctx = self.clone();
                    async move {
                        ctx.account
                            .peek(|account| cert_send(&account.root_ca_cert, &account.hostname))
                    }
                }),
            )
            .layer(axum::middleware::from_fn(https_redirect_if_public_http))
    }
}
impl UiContext for InitContext {
    // Serves the main UI bundle while the server is still initializing.
    const UI_DIR: &'static Dir<'static> = &else_empty_dir!(
        feature = "startd" =>
        include_dir::include_dir!("$CARGO_MANIFEST_DIR/../../web/dist/static/ui")
    );
    fn api() -> ParentHandler<Self> {
        // NOTE(review): the pre-refactor code served `init_api()` here;
        // confirm `main_api()` is generic over context and intended.
        main_api()
    }
    // Minimal stack: no auth or db-sync before initialization completes.
    fn middleware(server: Server<Self>) -> HttpServer<Self> {
        server.middleware(Cors::new())
    }
}
impl UiContext for DiagnosticContext {
    // Diagnostic mode reuses the main UI bundle.
    const UI_DIR: &'static Dir<'static> = &else_empty_dir!(
        feature = "startd" =>
        include_dir::include_dir!("$CARGO_MANIFEST_DIR/../../web/dist/static/ui")
    );
    fn api() -> ParentHandler<Self> {
        // NOTE(review): previously `diagnostic_api()` — confirm intentional.
        main_api()
    }
    fn middleware(server: Server<Self>) -> HttpServer<Self> {
        server.middleware(Cors::new())
    }
}
impl UiContext for SetupContext {
    // The dedicated setup-wizard bundle.
    const UI_DIR: &'static Dir<'static> = &else_empty_dir!(
        feature = "startd" =>
        include_dir::include_dir!("$CARGO_MANIFEST_DIR/../../web/dist/static/setup-wizard")
    );
    fn api() -> ParentHandler<Self> {
        // NOTE(review): previously `setup_api()` — confirm intentional.
        main_api()
    }
    fn middleware(server: Server<Self>) -> HttpServer<Self> {
        server.middleware(Cors::new())
    }
}
impl UiContext for InstallContext {
    // The dedicated install-wizard bundle.
    const UI_DIR: &'static Dir<'static> = &else_empty_dir!(
        feature = "startd" =>
        include_dir::include_dir!("$CARGO_MANIFEST_DIR/../../web/dist/static/install-wizard")
    );
    fn api() -> ParentHandler<Self> {
        // NOTE(review): previously `install_api()` — confirm intentional.
        main_api()
    }
    fn middleware(server: Server<Self>) -> HttpServer<Self> {
        server.middleware(Cors::new())
    }
}
@@ -111,24 +238,23 @@ pub fn rpc_router<C: Context + Clone + AsRef<RpcContinuations>>(
)
}
fn serve_ui(req: Request, ui_mode: UiMode) -> Result<Response, Error> {
fn serve_ui<C: UiContext>(req: Request) -> Result<Response, Error> {
let (request_parts, _body) = req.into_parts();
match &request_parts.method {
&Method::GET | &Method::HEAD => {
let uri_path = ui_mode.path(
request_parts
.uri
.path()
.strip_prefix('/')
.unwrap_or(request_parts.uri.path()),
);
let uri_path = request_parts
.uri
.path()
.strip_prefix('/')
.unwrap_or(request_parts.uri.path());
let file = EMBEDDED_UIS
.get_file(&*uri_path)
.or_else(|| EMBEDDED_UIS.get_file(&*ui_mode.path("index.html")));
let file = C::UI_DIR
.get_file(uri_path)
.or_else(|| C::UI_DIR.get_file("index.html"));
if let Some(file) = file {
FileData::from_embedded(&request_parts, file)?.into_response(&request_parts)
FileData::from_embedded(&request_parts, file, C::UI_DIR)?
.into_response(&request_parts)
} else {
Ok(not_found())
}
@@ -137,79 +263,15 @@ fn serve_ui(req: Request, ui_mode: UiMode) -> Result<Response, Error> {
}
}
pub fn setup_ui_router(ctx: SetupContext) -> Router {
rpc_router(
ctx.clone(),
Server::new(move || ready(Ok(ctx.clone())), setup_api()).middleware(Cors::new()),
)
.fallback(any(|request: Request| async move {
serve_ui(request, UiMode::Setup).unwrap_or_else(server_error)
}))
}
pub fn diagnostic_ui_router(ctx: DiagnosticContext) -> Router {
rpc_router(
ctx.clone(),
Server::new(move || ready(Ok(ctx.clone())), diagnostic_api()).middleware(Cors::new()),
)
.fallback(any(|request: Request| async move {
serve_ui(request, UiMode::Main).unwrap_or_else(server_error)
}))
}
pub fn install_ui_router(ctx: InstallContext) -> Router {
rpc_router(
ctx.clone(),
Server::new(move || ready(Ok(ctx.clone())), install_api()).middleware(Cors::new()),
)
.fallback(any(|request: Request| async move {
serve_ui(request, UiMode::Install).unwrap_or_else(server_error)
}))
}
pub fn init_ui_router(ctx: InitContext) -> Router {
rpc_router(
ctx.clone(),
Server::new(move || ready(Ok(ctx.clone())), init_api()).middleware(Cors::new()),
)
.fallback(any(|request: Request| async move {
serve_ui(request, UiMode::Main).unwrap_or_else(server_error)
}))
}
pub fn main_ui_router(ctx: RpcContext) -> Router {
rpc_router(ctx.clone(), {
let ctx = ctx.clone();
Server::new(move || ready(Ok(ctx.clone())), main_api::<RpcContext>())
.middleware(Cors::new())
.middleware(Auth::new())
.middleware(SyncDb::new())
})
.route("/proxy/{url}", {
let ctx = ctx.clone();
any(move |x::Path(url): x::Path<String>, request: Request| {
let ctx = ctx.clone();
async move {
proxy_request(ctx, request, url)
.await
.unwrap_or_else(server_error)
}
})
})
.nest("/s9pk", s9pk_router(ctx.clone()))
.route(
"/static/local-root-ca.crt",
get(move || {
let ctx = ctx.clone();
async move {
let account = ctx.account.read().await;
cert_send(&account.root_ca_cert, &account.hostname)
}
}),
)
.fallback(any(|request: Request| async move {
serve_ui(request, UiMode::Main).unwrap_or_else(server_error)
}))
/// Assembles the full HTTP router for a `UiContext`: the JSON-RPC endpoint
/// wrapped in the context's middleware stack, any context-specific extra
/// routes, and a fallback serving the embedded single-page app.
pub fn ui_router<C: UiContext>(ctx: C) -> Router {
    let extender = ctx.clone();
    let rpc_ctx = ctx.clone();
    let server = C::middleware(Server::new(move || ready(Ok(ctx.clone())), C::api()));
    let base = rpc_router(rpc_ctx, server);
    extender
        .extend_router(base)
        .fallback(any(|request: Request| async move {
            serve_ui::<C>(request).unwrap_or_else(server_error)
        }))
}
pub fn refresher() -> Router {
@@ -229,20 +291,6 @@ pub fn refresher() -> Router {
}))
}
pub fn redirecter() -> Router {
Router::new().fallback(get(|request: Request| async move {
Redirect::temporary(&format!(
"https://{}{}",
request
.headers()
.get(HOST)
.and_then(|s| s.to_str().ok())
.unwrap_or("localhost"),
request.uri()
))
}))
}
async fn proxy_request(ctx: RpcContext, request: Request, url: String) -> Result<Response, Error> {
if_authorized(&ctx, request, |mut request| async {
for header in PROXY_STRIP_HEADERS {
@@ -492,6 +540,7 @@ impl FileData {
fn from_embedded(
req: &RequestParts,
file: &'static include_dir::File<'static>,
ui_dir: &'static Dir<'static>,
) -> Result<Self, Error> {
let path = file.path();
let (encoding, data, len, content_range) = if let Some(range) = req.headers.get(RANGE) {
@@ -533,12 +582,12 @@ impl FileData {
.fold((None, file.contents()), |acc, e| {
if let Some(file) = (e == "br")
.then_some(())
.and_then(|_| EMBEDDED_UIS.get_file(format!("{}.br", path.display())))
.and_then(|_| ui_dir.get_file(format!("{}.br", path.display())))
{
(Some("br"), file.contents())
} else if let Some(file) = (e == "gzip" && acc.0 != Some("br"))
.then_some(())
.and_then(|_| EMBEDDED_UIS.get_file(format!("{}.gz", path.display())))
.and_then(|_| ui_dir.get_file(format!("{}.gz", path.display())))
{
(Some("gzip"), file.contents())
} else {

315
core/startos/src/net/tls.rs Normal file
View File

@@ -0,0 +1,315 @@
use std::sync::Arc;
use std::task::{Poll, ready};
use futures::future::BoxFuture;
use futures::stream::FuturesUnordered;
use futures::{FutureExt, StreamExt};
use imbl_value::InternedString;
use openssl::x509::X509Ref;
use tokio::io::AsyncWriteExt;
use tokio_rustls::LazyConfigAcceptor;
use tokio_rustls::rustls::crypto::CryptoProvider;
use tokio_rustls::rustls::pki_types::CertificateDer;
use tokio_rustls::rustls::server::{Acceptor, ClientHello, ResolvesServerCert};
use tokio_rustls::rustls::sign::CertifiedKey;
use tokio_rustls::rustls::{ClientConfig, RootCertStore, ServerConfig};
use visit_rs::{Visit, VisitFields};
use crate::net::web_server::{Accept, AcceptStream, MetadataVisitor};
use crate::prelude::*;
use crate::util::io::{BackTrackingIO, ReadWriter};
use crate::util::serde::MaybeUtf8String;
use crate::util::sync::SyncMutex;
/// Metadata for a connection that has completed a TLS handshake: the
/// transport's own metadata plus the negotiated TLS parameters.
#[derive(Debug, Clone, VisitFields)]
pub struct TlsMetadata<M> {
    pub inner: M,
    pub tls_info: TlsHandshakeInfo,
}

// Forward visitation to each field so extractors can find both the inner
// transport metadata and the TLS handshake info.
impl<V: MetadataVisitor<Result = ()>, M: Visit<V>> Visit<V> for TlsMetadata<M> {
    fn visit(&self, visitor: &mut V) -> <V as visit_rs::Visitor>::Result {
        self.visit_fields(visitor).collect()
    }
}
/// Client-hello details captured during the TLS handshake.
#[derive(Debug, Clone)]
pub struct TlsHandshakeInfo {
    /// Server name from SNI, when the client sent one.
    pub sni: Option<InternedString>,
    /// ALPN protocols offered by the client.
    pub alpn: Vec<MaybeUtf8String>,
}

// Leaf node: expose this struct itself to the visitor rather than recursing
// into its fields.
impl<V: MetadataVisitor> Visit<V> for TlsHandshakeInfo {
    fn visit(&self, visitor: &mut V) -> <V as visit_rs::Visitor>::Result {
        visitor.visit(self)
    }
}
/// Chooses a rustls `ServerConfig` for an incoming handshake based on the
/// client hello and the transport metadata. Returning `None` declines the
/// connection for this handler (allowing a chained handler to try instead).
pub trait TlsHandler<'a, A: Accept> {
    fn get_config(
        &'a mut self,
        hello: &'a ClientHello<'a>,
        metadata: &'a A::Metadata,
    ) -> impl Future<Output = Option<ServerConfig>> + Send + 'a;
}
/// Two TLS handlers tried in order: the first to produce a config wins.
#[derive(Clone)]
pub struct ChainedHandler<H0, H1>(pub H0, pub H1);

impl<'a, A, H0, H1> TlsHandler<'a, A> for ChainedHandler<H0, H1>
where
    A: Accept + 'a,
    <A as Accept>::Metadata: Send + Sync,
    H0: TlsHandler<'a, A> + Send,
    H1: TlsHandler<'a, A> + Send,
{
    /// Delegates to the primary handler, consulting the secondary only when
    /// the primary declines (returns `None`).
    async fn get_config(
        &'a mut self,
        hello: &'a ClientHello<'a>,
        metadata: &'a <A as Accept>::Metadata,
    ) -> Option<ServerConfig> {
        match self.0.get_config(hello, metadata).await {
            Some(config) => Some(config),
            None => self.1.get_config(hello, metadata).await,
        }
    }
}
/// Pairs an inner `TlsHandler` with a post-processing step applied to the
/// config it produces before the handshake proceeds.
#[derive(Clone)]
pub struct TlsHandlerWrapper<I, W> {
    pub inner: I,
    pub wrapper: W,
}

/// Transformation applied to a `ServerConfig` produced by another handler;
/// returning `None` declines the connection.
pub trait WrapTlsHandler<A: Accept> {
    fn wrap<'a>(
        &'a mut self,
        prev: ServerConfig,
        hello: &'a ClientHello<'a>,
        metadata: &'a <A as Accept>::Metadata,
    ) -> impl Future<Output = Option<ServerConfig>> + Send + 'a
    where
        Self: 'a;
}

impl<'a, A, I, W> TlsHandler<'a, A> for TlsHandlerWrapper<I, W>
where
    A: Accept + 'a,
    <A as Accept>::Metadata: Send + Sync,
    I: TlsHandler<'a, A> + Send,
    W: WrapTlsHandler<A> + Send,
{
    async fn get_config(
        &'a mut self,
        hello: &'a ClientHello<'a>,
        metadata: &'a <A as Accept>::Metadata,
    ) -> Option<ServerConfig> {
        // If the inner handler declines, the wrapper never runs.
        let prev = self.inner.get_config(hello, metadata).await?;
        self.wrapper.wrap(prev, hello, metadata).await
    }
}
#[derive(Debug)]
pub struct SingleCertResolver(pub Arc<CertifiedKey>);
impl ResolvesServerCert for SingleCertResolver {
fn resolve(&self, _: ClientHello) -> Option<Arc<CertifiedKey>> {
Some(self.0.clone())
}
}
/// `Accept` adapter that terminates TLS on every connection produced by the
/// wrapped acceptor, asking `tls_handler` for a per-connection config.
pub struct TlsListener<A: Accept, H: for<'a> TlsHandler<'a, A>> {
    pub accept: A,
    pub tls_handler: H,
    // Handshakes that did not complete within a single poll. Each future
    // yields its handler clone back so the most recent handler state is
    // restored into `tls_handler` when the handshake resolves.
    in_progress: SyncMutex<
        FuturesUnordered<
            BoxFuture<
                'static,
                (
                    H,
                    Result<Option<(TlsMetadata<A::Metadata>, AcceptStream)>, Error>,
                ),
            >,
        >,
    >,
}

impl<A: Accept, H: for<'a> TlsHandler<'a, A>> TlsListener<A, H> {
    /// Wraps `accept` so every accepted stream is TLS-terminated using
    /// configs supplied by `cert_handler`.
    pub fn new(accept: A, cert_handler: H) -> Self {
        Self {
            accept,
            tls_handler: cert_handler,
            in_progress: SyncMutex::new(FuturesUnordered::new()),
        }
    }
}
impl<A, H> Accept for TlsListener<A, H>
where
    A: Accept + 'static,
    A::Metadata: Send + 'static,
    for<'a> H: TlsHandler<'a, A> + Clone + Send + 'static,
{
    type Metadata = TlsMetadata<A::Metadata>;

    /// Accepts the next TLS-terminated stream.
    ///
    /// Drains any already-pending handshakes first, then pulls new raw
    /// connections from the inner acceptor. For each connection the client
    /// hello is read lazily; if the bytes turn out to be plain HTTP, the
    /// request is answered with a redirect instead of failing the handshake.
    /// Handshakes that don't finish in one poll are parked in `in_progress`.
    fn poll_accept(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
        self.in_progress.mutate(|in_progress| {
            loop {
                // Prefer completing an in-flight handshake over accepting new
                // connections.
                if !in_progress.is_empty() {
                    if let Poll::Ready(Some((handler, res))) = in_progress.poll_next_unpin(cx) {
                        if let Some(res) = res.transpose() {
                            self.tls_handler = handler;
                            return Poll::Ready(res);
                        }
                        // `Ok(None)` means the connection was declined or
                        // handled (e.g. HTTP redirect) — keep draining.
                        continue;
                    }
                }
                let (metadata, stream) = ready!(self.accept.poll_accept(cx)?);
                let mut tls_handler = self.tls_handler.clone();
                let mut fut = async move {
                    let res = async {
                        // Buffer the handshake bytes so they can be replayed
                        // if this turns out not to be TLS.
                        let mut acceptor = LazyConfigAcceptor::new(
                            Acceptor::default(),
                            BackTrackingIO::new(stream),
                        );
                        let mut mid: tokio_rustls::StartHandshake<BackTrackingIO<AcceptStream>> =
                            match (&mut acceptor).await {
                                Ok(a) => a,
                                Err(e) => {
                                    let mut stream =
                                        acceptor.take_io().or_not_found("acceptor io")?;
                                    let (_, buf) = stream.rewind();
                                    // Heuristic: if the first non-empty line
                                    // looks like an HTTP/1 request line, the
                                    // client spoke plaintext HTTP to a TLS
                                    // port — redirect it to https instead of
                                    // surfacing a handshake error.
                                    // NOTE(review): this regex is compiled
                                    // per failed handshake; consider hoisting
                                    // into a `LazyLock`.
                                    if std::str::from_utf8(buf)
                                        .ok()
                                        .and_then(|buf| {
                                            buf.lines()
                                                .map(|l| l.trim())
                                                .filter(|l| !l.is_empty())
                                                .next()
                                        })
                                        .map_or(false, |buf| {
                                            regex::Regex::new("[A-Z]+ (.+) HTTP/1")
                                                .unwrap()
                                                .is_match(buf)
                                        })
                                    {
                                        handle_http_on_https(stream).await.log_err();
                                        return Ok(None);
                                    } else {
                                        return Err(e).with_kind(ErrorKind::Network);
                                    }
                                }
                            };
                        let hello = mid.client_hello();
                        if let Some(cfg) = tls_handler.get_config(&hello, &metadata).await {
                            // Record SNI + ALPN for downstream consumers.
                            let metadata = TlsMetadata {
                                inner: metadata,
                                tls_info: TlsHandshakeInfo {
                                    sni: hello.server_name().map(InternedString::intern),
                                    alpn: hello
                                        .alpn()
                                        .into_iter()
                                        .flatten()
                                        .map(|a| MaybeUtf8String(a.to_vec()))
                                        .collect(),
                                },
                            };
                            // Flush any bytes captured by the back-tracking
                            // buffer before handing the stream to rustls.
                            let buffered = mid.io.stop_buffering();
                            mid.io
                                .write_all(&buffered)
                                .await
                                .with_kind(ErrorKind::Network)?;
                            return Ok(Some((
                                metadata,
                                Box::pin(mid.into_stream(Arc::new(cfg)).await?) as AcceptStream,
                            )));
                        }
                        // No handler produced a config: drop the connection.
                        Ok(None)
                    }
                    .await;
                    (tls_handler, res)
                }
                .boxed();
                // Poll once inline; park the future only if it's not ready.
                match fut.poll_unpin(cx) {
                    Poll::Pending => {
                        in_progress.push(fut);
                        return Poll::Pending;
                    }
                    Poll::Ready((handler, res)) => {
                        if let Some(res) = res.transpose() {
                            self.tls_handler = handler;
                            return Poll::Ready(res);
                        }
                    }
                };
            }
        })
    }
}
/// Serves a one-off HTTP/1 responder for a client that sent plaintext HTTP to
/// a TLS port: 307-redirects to the same URI over `https://`, or responds 400
/// when no `Host` header is present (the redirect target would be unknown).
async fn handle_http_on_https(stream: impl ReadWriter + Unpin + 'static) -> Result<(), Error> {
    use axum::body::Body;
    use axum::extract::Request;
    use axum::response::Response;
    use http::Uri;

    use crate::net::static_server::server_error;

    hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
        .serve_connection(
            hyper_util::rt::TokioIo::new(stream),
            hyper_util::service::TowerToHyperService::new(axum::Router::new().fallback(
                axum::routing::method_routing::any(move |req: Request| async move {
                    match async move {
                        let host = req
                            .headers()
                            .get(http::header::HOST)
                            .and_then(|host| host.to_str().ok());
                        if let Some(host) = host {
                            // Rebuild the request URI with the https scheme
                            // and the Host header as authority.
                            let uri = Uri::from_parts({
                                let mut parts = req.uri().to_owned().into_parts();
                                parts.scheme = Some("https".parse()?);
                                parts.authority = Some(host.parse()?);
                                parts
                            })?;
                            Response::builder()
                                .status(http::StatusCode::TEMPORARY_REDIRECT)
                                .header(http::header::LOCATION, uri.to_string())
                                .body(Body::default())
                        } else {
                            Response::builder()
                                .status(http::StatusCode::BAD_REQUEST)
                                .body(Body::from("Host header required"))
                        }
                    }
                    .await
                    {
                        Ok(a) => a,
                        Err(e) => {
                            tracing::warn!("Error redirecting http request on ssl port: {e}");
                            tracing::error!("{e:?}");
                            server_error(Error::new(e, ErrorKind::Network))
                        }
                    }
                }),
            )),
        )
        .await
        .map_err(|e| Error::new(color_eyre::eyre::Report::msg(e), ErrorKind::Network))
}
/// Builds a rustls `ClientConfig` that trusts exactly the given root
/// certificates (converted from OpenSSL `X509Ref`s) and presents no client
/// certificate.
pub fn client_config<'a, I: IntoIterator<Item = &'a X509Ref>>(
    crypto_provider: Arc<CryptoProvider>,
    root_certs: I,
) -> Result<ClientConfig, Error> {
    let mut certs = RootCertStore::empty();
    for cert in root_certs {
        certs
            .add(CertificateDer::from_slice(&cert.to_der()?))
            .with_kind(ErrorKind::OpenSsl)?;
    }
    // `crypto_provider` is owned by this function; pass it directly rather
    // than cloning the Arc (the previous `.clone()` was redundant).
    Ok(ClientConfig::builder_with_provider(crypto_provider)
        .with_safe_default_protocol_versions()
        .with_kind(ErrorKind::OpenSsl)?
        .with_root_certificates(certs)
        .with_no_client_auth())
}

View File

@@ -6,7 +6,7 @@ use std::sync::{Arc, Weak};
use std::time::{Duration, Instant};
use arti_client::config::onion_service::OnionServiceConfigBuilder;
use arti_client::{DataStream, TorClient, TorClientConfig};
use arti_client::{TorClient, TorClientConfig};
use base64::Engine;
use clap::Parser;
use color_eyre::eyre::eyre;
@@ -14,7 +14,7 @@ use futures::{FutureExt, StreamExt};
use helpers::NonDetachingJoinHandle;
use imbl_value::InternedString;
use itertools::Itertools;
use rpc_toolkit::{from_fn_async, Context, Empty, HandlerExt, ParentHandler};
use rpc_toolkit::{Context, Empty, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
@@ -34,8 +34,8 @@ use crate::util::actor::background::BackgroundJobQueue;
use crate::util::future::Until;
use crate::util::io::ReadWriter;
use crate::util::serde::{
deserialize_from_str, display_serializable, serialize_display, Base64, HandlerExtSerde,
WithIoFormat, BASE64,
BASE64, Base64, HandlerExtSerde, WithIoFormat, deserialize_from_str, display_serializable,
serialize_display,
};
use crate::util::sync::{SyncMutex, SyncRwLock, Watch};
@@ -628,11 +628,7 @@ impl TorController {
} else {
false
};
if rm {
s.remove(&addr)
} else {
None
}
if rm { s.remove(&addr) } else { None }
}) {
s.shutdown().await
} else {
@@ -861,11 +857,11 @@ impl OnionService {
})))
}
pub fn proxy_all<Rcs: FromIterator<Arc<()>>>(
pub async fn proxy_all<Rcs: FromIterator<Arc<()>>>(
&self,
bindings: impl IntoIterator<Item = (u16, SocketAddr)>,
) -> Rcs {
self.0.bindings.mutate(|b| {
) -> Result<Rcs, Error> {
Ok(self.0.bindings.mutate(|b| {
bindings
.into_iter()
.map(|(port, target)| {
@@ -879,7 +875,7 @@ impl OnionService {
}
})
.collect()
})
}))
}
pub fn gc(&self) -> bool {

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,10 @@
#[cfg(feature = "arti")]
mod arti;
#[cfg(not(feature = "arti"))]
mod ctor;
#[cfg(feature = "arti")]
pub use arti::{OnionAddress, OnionStore, TorController, TorSecretKey, tor_api};
#[cfg(not(feature = "arti"))]
pub use ctor::{OnionAddress, OnionStore, TorController, TorSecretKey, tor_api};

View File

@@ -2,16 +2,17 @@ use clap::Parser;
use imbl_value::InternedString;
use models::GatewayId;
use patch_db::json_ptr::JsonPointer;
use rpc_toolkit::{from_fn_async, Context, HandlerExt, ParentHandler};
use rpc_toolkit::{Context, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use ts_rs::TS;
use crate::context::{CliContext, RpcContext};
use crate::db::model::public::{NetworkInterfaceInfo, NetworkInterfaceType};
use crate::net::host::all_hosts;
use crate::prelude::*;
use crate::util::io::{write_file_atomic, TmpDir};
use crate::util::Invoke;
use crate::util::io::{TmpDir, write_file_atomic};
pub fn tunnel_api<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
@@ -125,14 +126,44 @@ pub async fn remove_tunnel(
return Ok(());
};
if existing.as_device_type().de()? != Some(NetworkInterfaceType::Wireguard) {
if existing.as_deref().as_device_type().de()? != Some(NetworkInterfaceType::Wireguard) {
return Err(Error::new(
eyre!("network interface {id} is not a proxy"),
ErrorKind::InvalidRequest,
));
}
ctx.db
.mutate(|db| {
for host in all_hosts(db) {
let host = host?;
host.as_public_domains_mut()
.mutate(|p| Ok(p.retain(|_, v| v.gateway != id)))?;
}
Ok(())
})
.await
.result?;
ctx.net_controller.net_iface.delete_iface(&id).await?;
ctx.db
.mutate(|db| {
for host in all_hosts(db) {
let host = host?;
host.as_bindings_mut().mutate(|b| {
Ok(b.values_mut().for_each(|v| {
v.net.private_disabled.remove(&id);
v.net.public_enabled.remove(&id);
}))
})?;
}
Ok(())
})
.await
.result?;
Ok(())
}

View File

@@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV6};
use std::path::Path;
@@ -7,13 +8,56 @@ use futures::stream::BoxStream;
use futures::{StreamExt, TryStreamExt};
use imbl_value::InternedString;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use models::GatewayId;
use nix::net::if_::if_nametoindex;
use tokio::net::{TcpListener, TcpStream};
use tokio::process::Command;
use crate::db::model::public::{IpInfo, NetworkInterfaceType};
use crate::prelude::*;
use crate::util::Invoke;
pub async fn load_ip_info() -> Result<BTreeMap<GatewayId, IpInfo>, Error> {
let output = String::from_utf8(
Command::new("ip")
.arg("-o")
.arg("addr")
.arg("show")
.invoke(crate::ErrorKind::Network)
.await?,
)?;
let err_fn = || {
Error::new(
eyre!("malformed output from `ip`"),
crate::ErrorKind::Network,
)
};
let mut res = BTreeMap::<GatewayId, IpInfo>::new();
for line in output.lines() {
let split = line.split_ascii_whitespace().collect::<Vec<_>>();
let iface = GatewayId::from(InternedString::from(*split.get(1).ok_or_else(&err_fn)?));
let subnet: IpNet = split.get(3).ok_or_else(&err_fn)?.parse()?;
let ip_info = res.entry(iface.clone()).or_default();
ip_info.name = iface.into();
ip_info.scope_id = split
.get(0)
.ok_or_else(&err_fn)?
.strip_suffix(":")
.ok_or_else(&err_fn)?
.parse()?;
ip_info.subnets.insert(subnet);
}
for (id, ip_info) in res.iter_mut() {
ip_info.device_type = probe_iface_type(id.as_str()).await;
}
Ok(res)
}
pub fn ipv6_is_link_local(addr: Ipv6Addr) -> bool {
(addr.segments()[0] & 0xffc0) == 0xfe80
}
@@ -75,6 +119,22 @@ pub async fn get_iface_ipv6_addr(iface: &str) -> Result<Option<(Ipv6Addr, Ipv6Ne
.transpose()?)
}
pub async fn probe_iface_type(iface: &str) -> Option<NetworkInterfaceType> {
match tokio::fs::read_to_string(Path::new("/sys/class/net").join(iface).join("uevent"))
.await
.ok()?
.lines()
.find_map(|l| l.strip_prefix("DEVTYPE="))
{
Some("wlan") => Some(NetworkInterfaceType::Wireless),
Some("bridge") => Some(NetworkInterfaceType::Bridge),
Some("wireguard") => Some(NetworkInterfaceType::Wireguard),
None if iface_is_physical(iface).await => Some(NetworkInterfaceType::Ethernet),
None if iface_is_loopback(iface).await => Some(NetworkInterfaceType::Loopback),
_ => None,
}
}
pub async fn iface_is_physical(iface: &str) -> bool {
tokio::fs::metadata(Path::new("/sys/class/net").join(iface).join("device"))
.await
@@ -87,6 +147,19 @@ pub async fn iface_is_wireless(iface: &str) -> bool {
.is_ok()
}
pub async fn iface_is_bridge(iface: &str) -> bool {
tokio::fs::metadata(Path::new("/sys/class/net").join(iface).join("bridge"))
.await
.is_ok()
}
pub async fn iface_is_loopback(iface: &str) -> bool {
tokio::fs::read_to_string(Path::new("/sys/class/net").join(iface).join("type"))
.await
.ok()
.map_or(false, |x| x.trim() == "772")
}
pub fn list_interfaces() -> BoxStream<'static, Result<String, Error>> {
try_stream! {
let mut ifaces = tokio::fs::read_dir("/sys/class/net").await?;

File diff suppressed because it is too large Load Diff

View File

@@ -1,70 +1,200 @@
use std::any::Any;
use std::collections::BTreeMap;
use std::future::Future;
use std::net::SocketAddr;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use std::task::{Poll, ready};
use std::time::Duration;
use axum::Router;
use futures::future::Either;
use futures::FutureExt;
use futures::{FutureExt, TryFutureExt};
use helpers::NonDetachingJoinHandle;
use http::Extensions;
use hyper_util::rt::{TokioIo, TokioTimer};
use tokio::net::{TcpListener, TcpStream};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use visit_rs::{Visit, VisitFields, Visitor};
use crate::context::{DiagnosticContext, InitContext, InstallContext, RpcContext, SetupContext};
use crate::net::gateway::{
lookup_info_by_addr, NetworkInterfaceListener, SelfContainedNetworkInterfaceListener,
};
use crate::net::static_server::{
diagnostic_ui_router, init_ui_router, install_ui_router, main_ui_router, redirecter, refresher,
setup_ui_router,
};
use crate::net::static_server::{UiContext, ui_router};
use crate::prelude::*;
use crate::util::actor::background::BackgroundJobQueue;
use crate::util::io::ReadWriter;
use crate::util::sync::{SyncRwLock, Watch};
pub struct Accepted {
pub https_redirect: bool,
pub stream: TcpStream,
pub type AcceptStream = Pin<Box<dyn ReadWriter + Send + 'static>>;
pub trait MetadataVisitor: Visitor<Result = ()> {
fn visit<M: Clone + Send + Sync + 'static>(&mut self, metadata: &M) -> Self::Result;
}
pub struct ExtensionVisitor<'a>(&'a mut Extensions);
impl<'a> Visitor for ExtensionVisitor<'a> {
type Result = ();
}
impl<'a> MetadataVisitor for ExtensionVisitor<'a> {
fn visit<M: Clone + Send + Sync + 'static>(&mut self, metadata: &M) -> Self::Result {
self.0.insert(metadata.clone());
}
}
impl<'a> Visit<ExtensionVisitor<'a>>
for Box<dyn for<'x> Visit<ExtensionVisitor<'x>> + Send + Sync + 'static>
{
fn visit(
&self,
visitor: &mut ExtensionVisitor<'a>,
) -> <ExtensionVisitor<'a> as Visitor>::Result {
(&**self).visit(visitor)
}
}
pub struct ExtractVisitor<T>(Option<T>);
impl<T> Visitor for ExtractVisitor<T> {
type Result = ();
}
impl<T: Clone + Send + Sync + 'static> MetadataVisitor for ExtractVisitor<T> {
fn visit<M: Clone + Send + Sync + 'static>(&mut self, metadata: &M) -> Self::Result {
if let Some(matching) = (metadata as &dyn Any).downcast_ref::<T>() {
self.0 = Some(matching.clone());
}
}
}
pub fn extract<
T: Clone + Send + Sync + 'static,
M: Visit<ExtractVisitor<T>> + Clone + Send + Sync + 'static,
>(
metadata: &M,
) -> Option<T> {
let mut visitor = ExtractVisitor(None);
visitor.visit(metadata);
visitor.0
}
#[derive(Clone, Copy, Debug)]
pub struct TcpMetadata {
pub peer_addr: SocketAddr,
pub local_addr: SocketAddr,
}
impl<V: MetadataVisitor> Visit<V> for TcpMetadata {
fn visit(&self, visitor: &mut V) -> <V as visit_rs::Visitor>::Result {
visitor.visit(self)
}
}
pub trait Accept {
fn poll_accept(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<Accepted, Error>>;
type Metadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>>;
fn into_dyn(self) -> DynAccept
where
Self: Sized + Send + Sync + 'static,
for<'a> Self::Metadata: Visit<ExtensionVisitor<'a>> + Send + Sync + 'static,
{
DynAccept::new(self)
}
}
impl Accept for Vec<TcpListener> {
fn poll_accept(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<Accepted, Error>> {
for listener in &*self {
if let Poll::Ready((stream, _)) = listener.poll_accept(cx)? {
return Poll::Ready(Ok(Accepted {
https_redirect: false,
stream,
}));
impl Accept for TcpListener {
type Metadata = TcpMetadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
if let Poll::Ready((stream, peer_addr)) = TcpListener::poll_accept(self, cx)? {
if let Err(e) = socket2::SockRef::from(&stream).set_tcp_keepalive(
&socket2::TcpKeepalive::new()
.with_time(Duration::from_secs(900))
.with_interval(Duration::from_secs(60))
.with_retries(5),
) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
return Poll::Ready(Ok((
TcpMetadata {
local_addr: self.local_addr()?,
peer_addr,
},
Box::pin(stream),
)));
}
Poll::Pending
}
}
impl<A> Accept for Vec<A>
where
A: Accept,
{
type Metadata = A::Metadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
for listener in self {
if let Poll::Ready(accepted) = listener.poll_accept(cx)? {
return Poll::Ready(Ok(accepted));
}
}
Poll::Pending
}
}
impl Accept for NetworkInterfaceListener {
fn poll_accept(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<Accepted, Error>> {
NetworkInterfaceListener::poll_accept(self, cx, &true).map(|res| {
res.map(|a| {
let public = self
.ip_info
.peek(|i| lookup_info_by_addr(i, a.bind).map_or(true, |(_, i)| i.public()));
Accepted {
https_redirect: public,
stream: a.stream,
}
})
})
#[derive(Clone, VisitFields)]
pub struct MapListenerMetadata<K, M> {
pub inner: M,
pub key: K,
}
impl<K, M, V> Visit<V> for MapListenerMetadata<K, M>
where
V: MetadataVisitor,
K: Visit<V>,
M: Visit<V>,
{
fn visit(&self, visitor: &mut V) -> <V as Visitor>::Result {
self.visit_fields(visitor).collect()
}
}
impl<A: Accept, B: Accept> Accept for Either<A, B> {
fn poll_accept(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<Accepted, Error>> {
impl<K, A> Accept for BTreeMap<K, A>
where
K: Clone,
A: Accept,
{
type Metadata = MapListenerMetadata<K, A::Metadata>;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
for (key, listener) in self {
if let Poll::Ready((metadata, stream)) = listener.poll_accept(cx)? {
return Poll::Ready(Ok((
MapListenerMetadata {
inner: metadata,
key: key.clone(),
},
stream,
)));
}
}
Poll::Pending
}
}
impl<A, B> Accept for Either<A, B>
where
A: Accept,
B: Accept<Metadata = A::Metadata>,
{
type Metadata = A::Metadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
match self {
Either::Left(a) => a.poll_accept(cx),
Either::Right(b) => b.poll_accept(cx),
@@ -72,7 +202,11 @@ impl<A: Accept, B: Accept> Accept for Either<A, B> {
}
}
impl<A: Accept> Accept for Option<A> {
fn poll_accept(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<Accepted, Error>> {
type Metadata = A::Metadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
match self {
None => Poll::Pending,
Some(a) => a.poll_accept(cx),
@@ -80,6 +214,68 @@ impl<A: Accept> Accept for Option<A> {
}
}
trait DynAcceptT: Send + Sync {
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<
Result<
(
Box<dyn for<'a> Visit<ExtensionVisitor<'a>> + Send + Sync>,
AcceptStream,
),
Error,
>,
>;
}
impl<A> DynAcceptT for A
where
A: Accept + Send + Sync,
for<'a> <A as Accept>::Metadata: Visit<ExtensionVisitor<'a>> + Send + Sync + 'static,
{
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<
Result<
(
Box<dyn for<'a> Visit<ExtensionVisitor<'a>> + Send + Sync>,
AcceptStream,
),
Error,
>,
> {
let (metadata, stream) = ready!(Accept::poll_accept(self, cx)?);
Poll::Ready(Ok((Box::new(metadata), stream)))
}
}
pub struct DynAccept(Box<dyn DynAcceptT>);
impl Accept for DynAccept {
type Metadata = Box<dyn for<'a> Visit<ExtensionVisitor<'a>> + Send + Sync>;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
DynAcceptT::poll_accept(&mut *self.0, cx)
}
fn into_dyn(self) -> DynAccept
where
Self: Sized,
for<'a> Self::Metadata: Visit<ExtensionVisitor<'a>> + Send + Sync + 'static,
{
self
}
}
impl DynAccept {
pub fn new<A>(accept: A) -> Self
where
A: Accept + Send + Sync + 'static,
for<'a> <A as Accept>::Metadata: Visit<ExtensionVisitor<'a>> + Send + Sync + 'static,
{
Self(Box::new(accept))
}
}
#[pin_project::pin_project]
pub struct Acceptor<A: Accept> {
acceptor: Watch<A>,
@@ -95,12 +291,15 @@ impl<A: Accept + Send + Sync + 'static> Acceptor<A> {
self.acceptor.poll_changed(cx)
}
fn poll_accept(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<Accepted, Error>> {
let _ = self.poll_changed(cx);
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(A::Metadata, AcceptStream), Error>> {
while self.poll_changed(cx).is_ready() {}
self.acceptor.peek_mut(|a| a.poll_accept(cx))
}
async fn accept(&mut self) -> Result<Accepted, Error> {
async fn accept(&mut self) -> Result<(A::Metadata, AcceptStream), Error> {
std::future::poll_fn(|cx| self.poll_accept(cx)).await
}
}
@@ -111,20 +310,73 @@ impl Acceptor<Vec<TcpListener>> {
))
}
}
pub type UpgradableListener =
Option<Either<SelfContainedNetworkInterfaceListener, NetworkInterfaceListener>>;
impl Acceptor<UpgradableListener> {
pub fn bind_upgradable(listener: SelfContainedNetworkInterfaceListener) -> Self {
Self::new(Some(Either::Left(listener)))
impl Acceptor<Vec<DynAccept>> {
pub async fn bind_dyn(listen: impl IntoIterator<Item = SocketAddr>) -> Result<Self, Error> {
Ok(Self::new(
futures::future::try_join_all(
listen
.into_iter()
.map(TcpListener::bind)
.map(|f| f.map_ok(DynAccept::new)),
)
.await?,
))
}
}
impl<K> Acceptor<BTreeMap<K, TcpListener>>
where
K: Ord + Clone + Send + Sync + 'static,
{
pub async fn bind_map(
listen: impl IntoIterator<Item = (K, SocketAddr)>,
) -> Result<Self, Error> {
Ok(Self::new(
futures::future::try_join_all(listen.into_iter().map(|(key, addr)| async move {
Ok::<_, Error>((
key,
TcpListener::bind(addr)
.await
.with_kind(ErrorKind::Network)?,
))
}))
.await?
.into_iter()
.collect(),
))
}
}
impl<K> Acceptor<BTreeMap<K, DynAccept>>
where
K: Ord + Clone + Send + Sync + 'static,
{
pub async fn bind_map_dyn(
listen: impl IntoIterator<Item = (K, SocketAddr)>,
) -> Result<Self, Error> {
Ok(Self::new(
futures::future::try_join_all(listen.into_iter().map(|(key, addr)| async move {
Ok::<_, Error>((
key,
TcpListener::bind(addr)
.await
.with_kind(ErrorKind::Network)?,
))
}))
.await?
.into_iter()
.map(|(key, listener)| (key, listener.into_dyn()))
.collect(),
))
}
}
pub struct WebServerAcceptorSetter<A: Accept> {
acceptor: Watch<A>,
}
impl<A: Accept, B: Accept> WebServerAcceptorSetter<Option<Either<A, B>>> {
impl<A, B> WebServerAcceptorSetter<Option<Either<A, B>>>
where
A: Accept,
B: Accept<Metadata = A::Metadata>,
{
pub fn try_upgrade<F: FnOnce(A) -> Result<B, Error>>(&self, f: F) -> Result<(), Error> {
let mut res = Ok(());
self.acceptor.send_modify(|a| {
@@ -151,20 +403,24 @@ impl<A: Accept> Deref for WebServerAcceptorSetter<A> {
pub struct WebServer<A: Accept> {
shutdown: oneshot::Sender<()>,
router: Watch<Option<Router>>,
router: Watch<Router>,
acceptor: Watch<A>,
thread: NonDetachingJoinHandle<()>,
}
impl<A: Accept + Send + Sync + 'static> WebServer<A> {
impl<A> WebServer<A>
where
A: Accept + Send + Sync + 'static,
for<'a> A::Metadata: Visit<ExtensionVisitor<'a>> + Send + Sync + 'static,
{
pub fn acceptor_setter(&self) -> WebServerAcceptorSetter<A> {
WebServerAcceptorSetter {
acceptor: self.acceptor.clone(),
}
}
pub fn new(mut acceptor: Acceptor<A>) -> Self {
pub fn new(mut acceptor: Acceptor<A>, router: Router) -> Self {
let acceptor_send = acceptor.acceptor.clone();
let router = Watch::<Option<Router>>::new(None);
let router = Watch::new(router);
let service = router.clone_unseen();
let (shutdown, shutdown_recv) = oneshot::channel();
let thread = NonDetachingJoinHandle::from(tokio::spawn(async move {
@@ -187,8 +443,14 @@ impl<A: Accept + Send + Sync + 'static> WebServer<A> {
}
}
struct SwappableRouter(Watch<Option<Router>>, bool);
impl hyper::service::Service<hyper::Request<hyper::body::Incoming>> for SwappableRouter {
struct SwappableRouter<M> {
router: Watch<Router>,
metadata: M,
}
impl<M: for<'a> Visit<ExtensionVisitor<'a>> + Send + Sync + 'static>
hyper::service::Service<hyper::Request<hyper::body::Incoming>>
for SwappableRouter<M>
{
type Response = <Router as tower_service::Service<
hyper::Request<hyper::body::Incoming>,
>>::Response;
@@ -199,19 +461,13 @@ impl<A: Accept + Send + Sync + 'static> WebServer<A> {
hyper::Request<hyper::body::Incoming>,
>>::Future;
fn call(&self, req: hyper::Request<hyper::body::Incoming>) -> Self::Future {
fn call(&self, mut req: hyper::Request<hyper::body::Incoming>) -> Self::Future {
use tower_service::Service;
if self.1 {
redirecter().call(req)
} else {
let router = self.0.read();
if let Some(mut router) = router {
router.call(req)
} else {
refresher().call(req)
}
}
self.metadata
.visit(&mut ExtensionVisitor(req.extensions_mut()));
self.router.read().call(req)
}
}
@@ -238,16 +494,16 @@ impl<A: Accept + Send + Sync + 'static> WebServer<A> {
let mut err = None;
for _ in 0..5 {
if let Err(e) = async {
let accepted = acceptor.accept().await?;
let (metadata, stream) = acceptor.accept().await?;
queue.add_job(
graceful.watch(
server
.serve_connection_with_upgrades(
TokioIo::new(accepted.stream),
SwappableRouter(
service.clone(),
accepted.https_redirect,
),
TokioIo::new(stream),
SwappableRouter {
router: service.clone(),
metadata,
},
)
.into_owned(),
),
@@ -300,26 +556,10 @@ impl<A: Accept + Send + Sync + 'static> WebServer<A> {
}
pub fn serve_router(&mut self, router: Router) {
self.router.send(Some(router))
self.router.send(router)
}
pub fn serve_main(&mut self, ctx: RpcContext) {
self.serve_router(main_ui_router(ctx))
}
pub fn serve_setup(&mut self, ctx: SetupContext) {
self.serve_router(setup_ui_router(ctx))
}
pub fn serve_diagnostic(&mut self, ctx: DiagnosticContext) {
self.serve_router(diagnostic_ui_router(ctx))
}
pub fn serve_install(&mut self, ctx: InstallContext) {
self.serve_router(install_ui_router(ctx))
}
pub fn serve_init(&mut self, ctx: InitContext) {
self.serve_router(init_ui_router(ctx))
pub fn serve_ui_for<C: UiContext>(&mut self, ctx: C) {
self.serve_router(ui_router(ctx))
}
}

View File

@@ -1017,6 +1017,31 @@ pub async fn synchronize_network_manager<P: AsRef<Path>>(
.await?;
}
Command::new("ip")
.arg("rule")
.arg("add")
.arg("pref")
.arg("1000")
.arg("from")
.arg("all")
.arg("lookup")
.arg("main")
.invoke(ErrorKind::Network)
.await
.log_err();
Command::new("ip")
.arg("rule")
.arg("add")
.arg("pref")
.arg("1100")
.arg("from")
.arg("all")
.arg("lookup")
.arg("default")
.invoke(ErrorKind::Network)
.await
.log_err();
Command::new("systemctl")
.arg("restart")
.arg("NetworkManager")

View File

@@ -3,13 +3,13 @@ use std::fmt;
use std::str::FromStr;
use chrono::{DateTime, Utc};
use clap::builder::ValueParserFactory;
use clap::Parser;
use clap::builder::ValueParserFactory;
use color_eyre::eyre::eyre;
use helpers::const_true;
use imbl_value::InternedString;
use models::{FromStrParser, PackageId};
use rpc_toolkit::{from_fn_async, Context, HandlerExt, ParentHandler};
use rpc_toolkit::{Context, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use tracing::instrument;
use ts_rs::TS;
@@ -463,7 +463,8 @@ pub fn notify<T: NotificationType>(
data,
seen: false,
},
)
)?;
Ok(())
}
#[test]

View File

@@ -361,6 +361,7 @@ pub async fn execute<C: Context>(
match ARCH {
"x86_64" => install.arg("--target=x86_64-efi"),
"aarch64" => install.arg("--target=arm64-efi"),
"riscv64" => install.arg("--target=riscv64-efi"),
_ => &mut install,
};
}

View File

@@ -5,6 +5,7 @@ use std::sync::Arc;
use chrono::Utc;
use clap::Parser;
use http::HeaderMap;
use imbl_value::InternedString;
use patch_db::PatchDb;
use reqwest::{Client, Proxy};
@@ -168,25 +169,38 @@ impl CallRemote<RegistryContext> for CliContext {
let url = if let Some(url) = self.registry_url.clone() {
url
} else if self.registry_hostname.is_some() {
format!(
let mut url: Url = format!(
"http://{}",
self.registry_listen.unwrap_or(DEFAULT_REGISTRY_LISTEN)
)
.parse()
.map_err(Error::from)?
.map_err(Error::from)?;
url.path_segments_mut()
.map_err(|_| Error::new(eyre!("cannot extend URL path"), ErrorKind::ParseUrl))?
.push("rpc")
.push("v0");
url
} else {
return Err(
Error::new(eyre!("`--registry` required"), ErrorKind::InvalidRequest).into(),
);
};
method = method.strip_prefix("registry.").unwrap_or(method);
let sig_context = self
.registry_hostname
.clone()
.or(url.host().as_ref().map(InternedString::from_display))
.or_not_found("registry hostname")?;
.or_else(|| url.host().as_ref().map(InternedString::from_display));
crate::middleware::signature::call_remote(self, url, &sig_context, method, params).await
crate::middleware::signature::call_remote(
self,
url,
HeaderMap::new(),
sig_context.as_deref(),
method,
params,
)
.await
}
}
@@ -195,61 +209,32 @@ impl CallRemote<RegistryContext, RegistryUrlParams> for RpcContext {
&self,
mut method: &str,
params: Value,
RegistryUrlParams { registry }: RegistryUrlParams,
RegistryUrlParams { mut registry }: RegistryUrlParams,
) -> Result<Value, RpcError> {
use reqwest::Method;
use reqwest::header::{ACCEPT, CONTENT_LENGTH, CONTENT_TYPE};
use rpc_toolkit::RpcResponse;
use rpc_toolkit::yajrc::{GenericRpcMethod, Id, RpcRequest};
let mut headers = HeaderMap::new();
headers.insert(
DEVICE_INFO_HEADER,
DeviceInfo::load(self).await?.to_header_value(),
);
registry
.path_segments_mut()
.map_err(|_| Error::new(eyre!("cannot extend URL path"), ErrorKind::ParseUrl))?
.push("rpc")
.push("v0");
let url = registry.join("rpc/v0")?;
method = method.strip_prefix("registry.").unwrap_or(method);
let sig_context = registry.host_str().map(InternedString::from);
let rpc_req = RpcRequest {
id: Some(Id::Number(0.into())),
method: GenericRpcMethod::<_, _, Value>::new(method),
crate::middleware::signature::call_remote(
self,
registry,
headers,
sig_context.as_deref(),
method,
params,
};
let body = serde_json::to_vec(&rpc_req)?;
let res = self
.client
.request(Method::POST, url)
.header(CONTENT_TYPE, "application/json")
.header(ACCEPT, "application/json")
.header(CONTENT_LENGTH, body.len())
.header(
DEVICE_INFO_HEADER,
DeviceInfo::load(self).await?.to_header_value(),
)
.body(body)
.send()
.await?;
if !res.status().is_success() {
let status = res.status();
let txt = res.text().await?;
let mut res = Err(Error::new(
eyre!("{}", status.canonical_reason().unwrap_or(status.as_str())),
ErrorKind::Network,
));
if !txt.is_empty() {
res = res.with_ctx(|_| (ErrorKind::Network, txt));
}
return res.map_err(From::from);
}
match res
.headers()
.get(CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
{
Some("application/json") => {
serde_json::from_slice::<RpcResponse>(&*res.bytes().await?)
.with_kind(ErrorKind::Deserialization)?
.result
}
_ => Err(Error::new(eyre!("unknown content type"), ErrorKind::Network).into()),
}
)
.await
}
}

View File

@@ -175,7 +175,7 @@ impl Middleware<RegistryContext> for DeviceInfoMiddleware {
async move {
if metadata.get_device_info {
if let Some(device_info) = &self.device_info {
request.params["__device_info"] =
request.params["__DeviceInfo_device_info"] =
to_value(&DeviceInfo::from_header_value(device_info)?)?;
}
}

View File

@@ -141,9 +141,3 @@ pub fn registry_router(ctx: RegistryContext) -> Router {
}),
)
}
impl<A: Accept + Send + Sync + 'static> WebServer<A> {
pub fn serve_registry(&mut self, ctx: RegistryContext) {
self.serve_router(registry_router(ctx))
}
}

View File

@@ -83,7 +83,7 @@ pub struct AddAssetParams {
pub platform: InternedString,
#[ts(type = "string")]
pub url: Url,
#[serde(rename = "__auth_signer")]
#[serde(rename = "__Auth_signer")]
#[ts(skip)]
pub signer: AnyVerifyingKey,
pub signature: AnySignature,
@@ -289,7 +289,7 @@ pub struct RemoveAssetParams {
pub version: Version,
#[ts(type = "string")]
pub platform: InternedString,
#[serde(rename = "__auth_signer")]
#[serde(rename = "__Auth_signer")]
#[ts(skip)]
pub signer: AnyVerifyingKey,
}

View File

@@ -56,7 +56,7 @@ pub struct SignAssetParams {
#[ts(type = "string")]
platform: InternedString,
#[ts(skip)]
#[serde(rename = "__auth_signer")]
#[serde(rename = "__Auth_signer")]
signer: AnyVerifyingKey,
signature: AnySignature,
}

View File

@@ -68,7 +68,7 @@ pub struct AddVersionParams {
pub source_version: VersionRange,
#[arg(skip)]
#[ts(skip)]
#[serde(rename = "__auth_signer")]
#[serde(rename = "__Auth_signer")]
pub signer: Option<AnyVerifyingKey>,
}
@@ -146,7 +146,7 @@ pub struct GetOsVersionParams {
platform: Option<InternedString>,
#[ts(skip)]
#[arg(skip)]
#[serde(rename = "__device_info")]
#[serde(rename = "__DeviceInfo_device_info")]
pub device_info: Option<DeviceInfo>,
}

View File

@@ -31,7 +31,7 @@ pub struct AddPackageParams {
#[ts(type = "string")]
pub url: Url,
#[ts(skip)]
#[serde(rename = "__auth_signer")]
#[serde(rename = "__Auth_signer")]
pub uploader: AnyVerifyingKey,
pub commitment: MerkleArchiveCommitment,
pub signature: AnySignature,
@@ -169,7 +169,7 @@ pub struct RemovePackageParams {
pub version: VersionString,
#[ts(skip)]
#[arg(skip)]
#[serde(rename = "__auth_signer")]
#[serde(rename = "__Auth_signer")]
pub signer: Option<AnyVerifyingKey>,
}

View File

@@ -51,7 +51,7 @@ pub struct GetPackageParams {
pub source_version: Option<VersionString>,
#[ts(skip)]
#[arg(skip)]
#[serde(rename = "__device_info")]
#[serde(rename = "__DeviceInfo_device_info")]
pub device_info: Option<DeviceInfo>,
#[serde(default)]
#[arg(default_value = "none")]

View File

@@ -3,12 +3,10 @@ use tokio::io::{AsyncSeek, AsyncWrite};
use crate::prelude::*;
use crate::util::io::TrackingIO;
#[async_trait::async_trait]
pub trait Sink: AsyncWrite + Unpin + Send {
async fn current_position(&mut self) -> Result<u64, Error>;
fn current_position(&mut self) -> impl Future<Output = Result<u64, Error>> + Send + '_;
}
#[async_trait::async_trait]
impl<S: AsyncWrite + AsyncSeek + Unpin + Send> Sink for S {
async fn current_position(&mut self) -> Result<u64, Error> {
use tokio::io::AsyncSeekExt;
@@ -17,7 +15,6 @@ impl<S: AsyncWrite + AsyncSeek + Unpin + Send> Sink for S {
}
}
#[async_trait::async_trait]
impl<W: AsyncWrite + Unpin + Send> Sink for TrackingIO<W> {
async fn current_position(&mut self) -> Result<u64, Error> {
Ok(self.position())

View File

@@ -16,10 +16,10 @@ use crate::s9pk::merkle_archive::source::TmpSource;
use crate::s9pk::merkle_archive::{Entry, MerkleArchive};
use crate::s9pk::v1::manifest::{Manifest as ManifestV1, PackageProcedure};
use crate::s9pk::v1::reader::S9pkReader;
use crate::s9pk::v2::pack::{ImageSource, PackSource, CONTAINER_TOOL};
use crate::s9pk::v2::pack::{CONTAINER_TOOL, ImageSource, PackSource};
use crate::s9pk::v2::{S9pk, SIG_CONTEXT};
use crate::util::io::{create_file, TmpDir};
use crate::util::Invoke;
use crate::util::io::{TmpDir, create_file};
pub const MAGIC_AND_VERSION: &[u8] = &[0x3b, 0x3b, 0x01];

View File

@@ -3,7 +3,7 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;
use clap::Parser;
use futures::future::{ready, BoxFuture};
use futures::future::{BoxFuture, ready};
use futures::{FutureExt, TryStreamExt};
use imbl_value::InternedString;
use models::{DataUrl, ImageId, PackageId, VersionString};
@@ -18,20 +18,20 @@ use crate::context::CliContext;
use crate::dependencies::{DependencyMetadata, MetadataSrc};
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::s9pk::S9pk;
use crate::s9pk::git_hash::GitHash;
use crate::s9pk::manifest::Manifest;
use crate::s9pk::merkle_archive::directory_contents::DirectoryContents;
use crate::s9pk::merkle_archive::source::http::HttpSource;
use crate::s9pk::merkle_archive::source::multi_cursor_file::MultiCursorFile;
use crate::s9pk::merkle_archive::source::{
into_dyn_read, ArchiveSource, DynFileSource, DynRead, FileSource, TmpSource,
ArchiveSource, DynFileSource, DynRead, FileSource, TmpSource, into_dyn_read,
};
use crate::s9pk::merkle_archive::{Entry, MerkleArchive};
use crate::s9pk::v2::SIG_CONTEXT;
use crate::s9pk::S9pk;
use crate::util::io::{create_file, open_file, TmpDir};
use crate::util::io::{TmpDir, create_file, open_file};
use crate::util::serde::IoFormat;
use crate::util::{new_guid, Invoke, PathOrUrl};
use crate::util::{Invoke, PathOrUrl, new_guid};
#[cfg(not(feature = "docker"))]
pub const CONTAINER_TOOL: &str = "podman";
@@ -369,10 +369,12 @@ impl ImageSource {
workdir,
..
} => {
vec![workdir
.as_deref()
.unwrap_or(Path::new("."))
.join(dockerfile.as_deref().unwrap_or(Path::new("Dockerfile")))]
vec![
workdir
.as_deref()
.unwrap_or(Path::new("."))
.join(dockerfile.as_deref().unwrap_or(Path::new("Dockerfile"))),
]
}
Self::DockerTag(_) => Vec::new(),
}
@@ -414,6 +416,8 @@ impl ImageSource {
"--platform=linux/amd64".to_owned()
} else if arch == "aarch64" {
"--platform=linux/arm64".to_owned()
} else if arch == "riscv64" {
"--platform=linux/riscv64".to_owned()
} else {
format!("--platform=linux/{arch}")
};
@@ -476,6 +480,8 @@ impl ImageSource {
"--platform=linux/amd64".to_owned()
} else if arch == "aarch64" {
"--platform=linux/arm64".to_owned()
} else if arch == "riscv64" {
"--platform=linux/riscv64".to_owned()
} else {
format!("--platform=linux/{arch}")
};

View File

@@ -6,7 +6,7 @@ use std::time::{Duration, SystemTime};
use clap::Parser;
use futures::future::join_all;
use helpers::NonDetachingJoinHandle;
use imbl::{vector, Vector};
use imbl::{Vector, vector};
use imbl_value::InternedString;
use models::{HostId, PackageId, ServiceInterfaceId};
use serde::{Deserialize, Serialize};

View File

@@ -95,7 +95,7 @@ pub async fn get_ssl_certificate(
.as_entries()?
.into_iter()
.flat_map(|(_, net)| net.as_ip_info().transpose_ref())
.flat_map(|net| net.as_subnets().de().log_err())
.flat_map(|net| net.as_deref().as_subnets().de().log_err())
.flatten()
.any(|s| s.addr() == ip)
{

View File

@@ -15,12 +15,12 @@ use futures::future::BoxFuture;
use futures::stream::FusedStream;
use futures::{FutureExt, SinkExt, StreamExt, TryStreamExt};
use helpers::NonDetachingJoinHandle;
use imbl_value::{json, InternedString};
use imbl_value::{InternedString, json};
use itertools::Itertools;
use models::{ActionId, HostId, ImageId, PackageId};
use nix::sys::signal::Signal;
use persistent_container::{PersistentContainer, Subcontainer};
use rpc_toolkit::{from_fn_async, CallRemoteHandler, Empty, HandlerArgs, HandlerFor};
use rpc_toolkit::{HandlerArgs, HandlerFor};
use serde::{Deserialize, Serialize};
use service_actor::ServiceActor;
use start_stop::StartStop;
@@ -47,11 +47,11 @@ use crate::service::action::update_tasks;
use crate::service::rpc::{ExitParams, InitKind};
use crate::service::service_map::InstallProgressHandles;
use crate::service::uninstall::cleanup;
use crate::util::Never;
use crate::util::actor::concurrent::ConcurrentActor;
use crate::util::io::{create_file, delete_file, AsyncReadStream, TermSize};
use crate::util::io::{AsyncReadStream, TermSize, create_file, delete_file};
use crate::util::net::WebSocketExt;
use crate::util::serde::Pem;
use crate::util::Never;
use crate::volume::data_dir;
use crate::{CAP_1_KiB, DATA_DIR};
@@ -707,57 +707,6 @@ pub async fn rebuild(ctx: RpcContext, RebuildParams { id }: RebuildParams) -> Re
Ok(())
}
#[derive(Deserialize, Serialize, Parser, TS)]
pub struct ConnectParams {
pub id: PackageId,
}
pub async fn connect_rpc(
ctx: RpcContext,
ConnectParams { id }: ConnectParams,
) -> Result<Guid, Error> {
let id_ref = &id;
crate::lxc::connect(
&ctx,
ctx.services
.get(&id)
.await
.as_ref()
.or_not_found(lazy_format!("service for {id_ref}"))?
.seed
.persistent_container
.lxc_container
.get()
.or_not_found(lazy_format!("container for {id_ref}"))?,
)
.await
}
pub async fn connect_rpc_cli(
HandlerArgs {
context,
parent_method,
method,
params,
inherited_params,
raw_params,
}: HandlerArgs<CliContext, ConnectParams>,
) -> Result<(), Error> {
let ctx = context.clone();
let guid = CallRemoteHandler::<CliContext, _, _>::new(from_fn_async(connect_rpc))
.handle_async(HandlerArgs {
context,
parent_method,
method,
params: rpc_toolkit::util::Flat(params, Empty {}),
inherited_params,
raw_params,
})
.await?;
crate::lxc::connect_cli(&ctx, guid).await
}
#[derive(Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct AttachParams {
@@ -768,7 +717,7 @@ pub struct AttachParams {
pub stderr_tty: bool,
pub pty_size: Option<TermSize>,
#[ts(skip)]
#[serde(rename = "__auth_session")]
#[serde(rename = "__Auth_session")]
session: Option<InternedString>,
#[ts(type = "string | null")]
subcontainer: Option<InternedString>,

View File

@@ -1,6 +1,8 @@
use std::sync::Arc;
use std::time::Duration;
use futures::FutureExt;
use futures::future::{BoxFuture, Either};
use imbl::vector;
use super::ServiceActorSeed;
@@ -16,131 +18,162 @@ use crate::util::actor::background::BackgroundJobQueue;
#[derive(Clone)]
pub(super) struct ServiceActor(pub(super) Arc<ServiceActorSeed>);
enum ServiceActorLoopNext {
Wait,
DontWait,
}
impl Actor for ServiceActor {
fn init(&mut self, jobs: &BackgroundJobQueue) {
let seed = self.0.clone();
let mut current = seed.persistent_container.state.subscribe();
jobs.add_job(async move {
let _ = current.wait_for(|s| s.rt_initialized).await;
let mut start_stop_task: Option<Either<_, _>> = None;
loop {
match service_actor_loop(&current, &seed).await {
ServiceActorLoopNext::Wait => tokio::select! {
_ = current.changed() => (),
},
ServiceActorLoopNext::DontWait => (),
}
let wait = match service_actor_loop(&current, &seed, &mut start_stop_task).await {
Ok(()) => Either::Right(current.changed().then(|res| async move {
match res {
Ok(()) => (),
Err(_) => futures::future::pending().await,
}
})),
Err(e) => {
tracing::error!("error synchronizing state of service: {e}");
tracing::debug!("{e:?}");
seed.synchronized.notify_waiters();
tracing::error!("Retrying in {}s...", SYNC_RETRY_COOLDOWN_SECONDS);
Either::Left(tokio::time::sleep(Duration::from_secs(
SYNC_RETRY_COOLDOWN_SECONDS,
)))
}
};
tokio::pin!(wait);
let start_stop_handler = async {
match &mut start_stop_task {
Some(task) => {
let err = task.await.log_err().is_none(); // TODO: ideally this error should be sent to service logs
start_stop_task.take();
if err {
tokio::time::sleep(Duration::from_secs(
SYNC_RETRY_COOLDOWN_SECONDS,
))
.await;
}
}
_ => futures::future::pending().await,
}
};
tokio::pin!(start_stop_handler);
futures::future::select(wait, start_stop_handler).await;
}
});
}
}
async fn service_actor_loop(
async fn service_actor_loop<'a>(
current: &tokio::sync::watch::Receiver<super::persistent_container::ServiceState>,
seed: &Arc<ServiceActorSeed>,
) -> ServiceActorLoopNext {
seed: &'a Arc<ServiceActorSeed>,
start_stop_task: &mut Option<
Either<BoxFuture<'a, Result<(), Error>>, BoxFuture<'a, Result<(), Error>>>,
>,
) -> Result<(), Error> {
let id = &seed.id;
let kinds = current.borrow().kinds();
if let Err(e) = async {
let major_changes_state = seed
.ctx
.db
.mutate(|d| {
if let Some(i) = d.as_public_mut().as_package_data_mut().as_idx_mut(&id) {
let previous = i.as_status().de()?;
let main_status = match &kinds {
ServiceStateKinds {
transition_state: Some(TransitionKind::Restarting),
..
} => MainStatus::Restarting,
ServiceStateKinds {
transition_state: Some(TransitionKind::BackingUp),
..
} => previous.backing_up(),
ServiceStateKinds {
running_status: Some(status),
desired_state: StartStop::Start,
..
} => MainStatus::Running {
started: status.started,
health: previous.health().cloned().unwrap_or_default(),
},
ServiceStateKinds {
running_status: None,
desired_state: StartStop::Start,
..
} => MainStatus::Starting {
health: previous.health().cloned().unwrap_or_default(),
},
ServiceStateKinds {
running_status: Some(_),
desired_state: StartStop::Stop,
..
} => MainStatus::Stopping,
ServiceStateKinds {
running_status: None,
desired_state: StartStop::Stop,
..
} => MainStatus::Stopped,
};
i.as_status_mut().ser(&main_status)?;
return Ok(previous
.major_changes(&main_status)
.then_some((previous, main_status)));
}
Ok(None)
})
.await
.result?;
if let Some((previous, new_state)) = major_changes_state {
if let Some(callbacks) = seed.ctx.callbacks.get_status(id) {
callbacks
.call(vector![to_value(&previous)?, to_value(&new_state)?])
.await?;
let major_changes_state = seed
.ctx
.db
.mutate(|d| {
if let Some(i) = d.as_public_mut().as_package_data_mut().as_idx_mut(&id) {
let previous = i.as_status().de()?;
let main_status = match &kinds {
ServiceStateKinds {
transition_state: Some(TransitionKind::Restarting),
..
} => MainStatus::Restarting,
ServiceStateKinds {
transition_state: Some(TransitionKind::BackingUp),
..
} => previous.backing_up(),
ServiceStateKinds {
running_status: Some(status),
desired_state: StartStop::Start,
..
} => MainStatus::Running {
started: status.started,
health: previous.health().cloned().unwrap_or_default(),
},
ServiceStateKinds {
running_status: None,
desired_state: StartStop::Start,
..
} => MainStatus::Starting {
health: previous.health().cloned().unwrap_or_default(),
},
ServiceStateKinds {
running_status: Some(_),
desired_state: StartStop::Stop,
..
} => MainStatus::Stopping,
ServiceStateKinds {
running_status: None,
desired_state: StartStop::Stop,
..
} => MainStatus::Stopped,
};
i.as_status_mut().ser(&main_status)?;
return Ok(previous
.major_changes(&main_status)
.then_some((previous, main_status)));
}
Ok(None)
})
.await
.result?;
if let Some((previous, new_state)) = major_changes_state {
if let Some(callbacks) = seed.ctx.callbacks.get_status(id) {
callbacks
.call(vector![to_value(&previous)?, to_value(&new_state)?])
.await?;
}
seed.synchronized.notify_waiters();
match kinds {
ServiceStateKinds {
running_status: None,
desired_state: StartStop::Start,
..
} => {
seed.persistent_container.start().await?;
}
ServiceStateKinds {
running_status: Some(_),
desired_state: StartStop::Stop,
..
} => {
seed.persistent_container.stop().await?;
seed.persistent_container
.state
.send_if_modified(|s| s.running_status.take().is_some());
}
_ => (),
};
Ok::<_, Error>(())
}
.await
{
tracing::error!("error synchronizing state of service: {e}");
tracing::debug!("{e:?}");
seed.synchronized.notify_waiters();
tracing::error!("Retrying in {}s...", SYNC_RETRY_COOLDOWN_SECONDS);
tokio::time::sleep(Duration::from_secs(SYNC_RETRY_COOLDOWN_SECONDS)).await;
return ServiceActorLoopNext::DontWait;
}
seed.synchronized.notify_waiters();
ServiceActorLoopNext::Wait
match kinds {
ServiceStateKinds {
running_status: None,
desired_state: StartStop::Start,
..
} => {
let task = start_stop_task
.take()
.filter(|task| matches!(task, Either::Right(_)));
*start_stop_task = Some(
task.unwrap_or_else(|| Either::Right(seed.persistent_container.start().boxed())),
);
}
ServiceStateKinds {
running_status: Some(_),
desired_state: StartStop::Stop,
..
} => {
let task = start_stop_task
.take()
.filter(|task| matches!(task, Either::Left(_)));
*start_stop_task = Some(task.unwrap_or_else(|| {
Either::Left(
async {
seed.persistent_container.stop().await?;
seed.persistent_container
.state
.send_if_modified(|s| s.running_status.take().is_some());
Ok::<_, Error>(())
}
.boxed(),
)
}));
}
_ => (),
};
Ok(())
}

View File

@@ -1,6 +1,5 @@
use std::collections::BTreeSet;
use std::fmt;
use std::sync::Arc;
use std::time::Duration;
use chrono::Utc;
@@ -10,8 +9,6 @@ use futures::{FutureExt, TryStreamExt};
use imbl::vector;
use imbl_value::InternedString;
use rpc_toolkit::{Context, Empty, HandlerExt, ParentHandler, from_fn_async};
use rustls::RootCertStore;
use rustls_pki_types::CertificateDer;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tokio::process::Command;
use tokio::sync::broadcast::Receiver;
@@ -1024,7 +1021,7 @@ pub struct TestSmtpParams {
#[arg(long)]
pub login: String,
#[arg(long)]
pub password: Option<String>,
pub password: String,
}
pub async fn test_smtp(
_: RpcContext,
@@ -1037,74 +1034,23 @@ pub async fn test_smtp(
password,
}: TestSmtpParams,
) -> Result<(), Error> {
#[cfg(feature = "mail-send")]
{
use mail_send::SmtpClientBuilder;
use mail_send::mail_builder::{self, MessageBuilder};
use rustls_pki_types::pem::PemObject;
use lettre::message::header::ContentType;
use lettre::transport::smtp::authentication::Credentials;
use lettre::{AsyncSmtpTransport, AsyncTransport, Message, Tokio1Executor};
let Some(pass_val) = password else {
return Err(Error::new(
eyre!("mail-send requires a password"),
ErrorKind::InvalidRequest,
));
};
let mut root_cert_store = RootCertStore::empty();
let pem = tokio::fs::read("/etc/ssl/certs/ca-certificates.crt").await?;
for cert in CertificateDer::pem_slice_iter(&pem) {
root_cert_store.add_parsable_certificates([cert.with_kind(ErrorKind::OpenSsl)?]);
}
let cfg = Arc::new(
rustls::ClientConfig::builder_with_provider(Arc::new(
rustls::crypto::ring::default_provider(),
))
.with_safe_default_protocol_versions()?
.with_root_certificates(root_cert_store)
.with_no_client_auth(),
);
let client = SmtpClientBuilder::new_with_tls_config(server, port, cfg)
.implicit_tls(false)
.credentials((login.split("@").next().unwrap().to_owned(), pass_val));
fn parse_address<'a>(addr: &'a str) -> mail_builder::headers::address::Address<'a> {
if addr.find("<").map_or(false, |start| {
addr.find(">").map_or(false, |end| start < end)
}) {
addr.split_once("<")
.map(|(name, addr)| (name.trim(), addr.strip_suffix(">").unwrap_or(addr)))
.unwrap()
.into()
} else {
addr.into()
}
}
let message = MessageBuilder::new()
.from(parse_address(&from))
.to(parse_address(&to))
.subject("StartOS Test Email")
.text_body("This is a test email sent from your StartOS Server");
client
.connect()
.await
.map_err(|e| {
Error::new(
eyre!("mail-send connection error: {:?}", e),
ErrorKind::Unknown,
)
})?
.send(message)
.await
.map_err(|e| Error::new(eyre!("mail-send send error: {:?}", e), ErrorKind::Unknown))?;
Ok(())
}
#[cfg(not(feature = "mail-send"))]
Err(Error::new(
eyre!("test-smtp requires mail-send feature to be enabled"),
ErrorKind::InvalidRequest,
))
AsyncSmtpTransport::<Tokio1Executor>::relay(&server)?
.credentials(Credentials::new(login, password))
.build()
.send(
Message::builder()
.from(from.parse()?)
.to(to.parse()?)
.subject("StartOS Test Email")
.header(ContentType::TEXT_PLAIN)
.body("This is a test email sent from your StartOS Server".to_owned())?,
)
.await?;
Ok(())
}
#[tokio::test]

View File

@@ -1,49 +1,61 @@
use std::net::Ipv4Addr;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use clap::Parser;
use imbl_value::InternedString;
use ipnet::Ipv4Net;
use rpc_toolkit::{from_fn_async, Context, Empty, HandlerExt, ParentHandler};
use rpc_toolkit::{Context, Empty, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use crate::context::CliContext;
use crate::prelude::*;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::wg::WgSubnetConfig;
use crate::tunnel::wg::{WgConfig, WgSubnetClients, WgSubnetConfig};
use crate::util::serde::{HandlerExtSerde, display_serializable};
/// Builds the `tunnel` RPC/CLI handler tree: web server, db, auth,
/// subnet, device, and port-forward subcommands.
///
/// Cleanup: removed the commented-out port-forward subcommand block that
/// duplicated the live `"port-forward"` registration below.
pub fn tunnel_api<C: Context>() -> ParentHandler<C> {
    ParentHandler::new()
        .subcommand("web", super::web::web_api::<C>())
        .subcommand(
            "db",
            super::db::db_api::<C>()
                .with_about("Commands to interact with the db i.e. dump and apply"),
        )
        .subcommand(
            "auth",
            super::auth::auth_api::<C>().with_about("Add or remove authorized clients"),
        )
        .subcommand(
            "subnet",
            subnet_api::<C>().with_about("Add, remove, or modify subnets"),
        )
        .subcommand(
            "device",
            device_api::<C>().with_about("Add, remove, or list devices in subnets"),
        )
        .subcommand(
            "port-forward",
            ParentHandler::<C>::new()
                .subcommand(
                    "add",
                    from_fn_async(add_forward)
                        .with_metadata("sync_db", Value::Bool(true))
                        .no_display()
                        .with_about("Add a new port forward")
                        .with_call_remote::<CliContext>(),
                )
                .subcommand(
                    "remove",
                    from_fn_async(remove_forward)
                        .with_metadata("sync_db", Value::Bool(true))
                        .no_display()
                        .with_about("Remove a port forward")
                        .with_call_remote::<CliContext>(),
                ),
        )
}
/// Parameters identifying a tunnel subnet by its CIDR (e.g. `10.0.0.0/24`).
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct SubnetParams {
    subnet: Ipv4Net,
}
@@ -68,44 +80,89 @@ pub fn subnet_api<C: Context>() -> ParentHandler<C, SubnetParams> {
.with_about("Remove a subnet")
.with_call_remote::<CliContext>(),
)
// .subcommand(
// "set-default-forward-target",
// from_fn_async(set_default_forward_target)
// .with_metadata("sync_db", Value::Bool(true))
// .no_display()
// .with_about("Set the default target for port forwarding")
// .with_call_remote::<CliContext>(),
// )
// .subcommand(
// "add-device",
// from_fn_async(add_device)
// .with_metadata("sync_db", Value::Bool(true))
// .no_display()
// .with_about("Add a device to a subnet")
// .with_call_remote::<CliContext>(),
// )
// .subcommand(
// "remove-device",
// from_fn_async(remove_device)
// .with_metadata("sync_db", Value::Bool(true))
// .no_display()
// .with_about("Remove a device from a subnet")
// .with_call_remote::<CliContext>(),
// )
}
/// Builds the `tunnel device` handler tree: add/remove/list devices in a
/// subnet, and show a device's WireGuard client configuration.
pub fn device_api<C: Context>() -> ParentHandler<C> {
    ParentHandler::new()
        .subcommand(
            "add",
            from_fn_async(add_device)
                .with_metadata("sync_db", Value::Bool(true))
                .no_display()
                .with_about("Add a device to a subnet")
                .with_call_remote::<CliContext>(),
        )
        .subcommand(
            "remove",
            from_fn_async(remove_device)
                .with_metadata("sync_db", Value::Bool(true))
                .no_display()
                .with_about("Remove a device from a subnet")
                .with_call_remote::<CliContext>(),
        )
        .subcommand(
            "list",
            from_fn_async(list_devices)
                .with_display_serializable()
                .with_custom_display_fn(|HandlerArgs { params, .. }, res| {
                    use prettytable::*;
                    // Honor an explicit --format flag; otherwise render a
                    // human-readable table.
                    if let Some(format) = params.format {
                        return display_serializable(format, res);
                    }
                    let mut table = Table::new();
                    table.add_row(row![bc => "NAME", "IP", "PUBLIC KEY"]);
                    for (ip, config) in res.clients.0 {
                        table.add_row(row![config.name, ip, config.key.verifying_key()]);
                    }
                    table.print_tty(false)?;
                    Ok(())
                })
                .with_about("List devices in a subnet")
                .with_call_remote::<CliContext>(),
        )
        .subcommand(
            "show-config",
            from_fn_async(show_config)
                .with_about("Show the WireGuard configuration for a device")
                .with_call_remote::<CliContext>(),
        )
}
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct AddSubnetParams {
name: InternedString,
}
pub async fn add_subnet(
ctx: TunnelContext,
_: Empty,
SubnetParams { subnet }: SubnetParams,
AddSubnetParams { name }: AddSubnetParams,
SubnetParams { mut subnet }: SubnetParams,
) -> Result<(), Error> {
if subnet.prefix_len() > 24 {
return Err(Error::new(
eyre!("invalid subnet"),
ErrorKind::InvalidRequest,
));
}
let addr = subnet
.hosts()
.next()
.ok_or_else(|| Error::new(eyre!("invalid subnet"), ErrorKind::InvalidRequest))?;
subnet = Ipv4Net::new_assert(addr, subnet.prefix_len());
let server = ctx
.db
.mutate(|db| {
let map = db.as_wg_mut().as_subnets_mut();
if !map.contains_key(&subnet)? {
map.insert(&subnet, &WgSubnetConfig::new())?;
}
map.upsert(&subnet, || {
Ok(WgSubnetConfig::new(InternedString::default()))
})?
.as_name_mut()
.ser(&name)?;
db.as_wg().de()
})
.await
@@ -128,3 +185,221 @@ pub async fn remove_subnet(
.result?;
server.sync().await
}
/// Parameters for `add_device`: target subnet, a display name, and an
/// optional fixed IP (auto-assigned from the subnet when omitted).
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct AddDeviceParams {
    subnet: Ipv4Net,
    name: InternedString,
    ip: Option<Ipv4Addr>,
}

/// Adds (or renames) a WireGuard client device in `subnet`, then re-syncs
/// the live WireGuard interface with the updated configuration.
pub async fn add_device(
    ctx: TunnelContext,
    AddDeviceParams { subnet, name, ip }: AddDeviceParams,
) -> Result<(), Error> {
    let server = ctx
        .db
        .mutate(|db| {
            db.as_wg_mut()
                .as_subnets_mut()
                .as_idx_mut(&subnet)
                .or_not_found(&subnet)?
                .as_clients_mut()
                .mutate(|WgSubnetClients(clients)| {
                    // Use the requested IP, or pick the first free host
                    // address (skipping the subnet's own gateway address).
                    let ip = if let Some(ip) = ip {
                        ip
                    } else {
                        subnet
                            .hosts()
                            .find(|ip| !clients.contains_key(ip) && *ip != subnet.addr())
                            .ok_or_else(|| {
                                Error::new(
                                    eyre!("no available ips in subnet"),
                                    ErrorKind::InvalidRequest,
                                )
                            })?
                    };
                    // Reject .0/.255 last octets, the gateway address
                    // itself, and anything outside the subnet.
                    if ip.octets()[3] == 0 || ip.octets()[3] == 255 {
                        return Err(Error::new(eyre!("invalid ip"), ErrorKind::InvalidRequest));
                    }
                    if ip == subnet.addr() {
                        return Err(Error::new(eyre!("invalid ip"), ErrorKind::InvalidRequest));
                    }
                    if !subnet.contains(&ip) {
                        return Err(Error::new(
                            eyre!("ip not in subnet"),
                            ErrorKind::InvalidRequest,
                        ));
                    }
                    // If the IP already has a device, keep its keypair and
                    // only update the display name.
                    let client = clients
                        .entry(ip)
                        .or_insert_with(|| WgConfig::generate(name.clone()));
                    client.name = name;
                    Ok(())
                })?;
            db.as_wg().de()
        })
        .await
        .result?;
    server.sync().await
}
/// Parameters for `remove_device`: the subnet and the device's assigned IP.
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct RemoveDeviceParams {
    subnet: Ipv4Net,
    ip: Ipv4Addr,
}

/// Removes the WireGuard client at `ip` from `subnet`, erroring if either
/// the subnet or the device does not exist, then re-syncs the live
/// WireGuard interface with the updated configuration.
pub async fn remove_device(
    ctx: TunnelContext,
    RemoveDeviceParams { subnet, ip }: RemoveDeviceParams,
) -> Result<(), Error> {
    let server = ctx
        .db
        .mutate(|db| {
            db.as_wg_mut()
                .as_subnets_mut()
                .as_idx_mut(&subnet)
                .or_not_found(&subnet)?
                .as_clients_mut()
                .remove(&ip)?
                .or_not_found(&ip)?;
            db.as_wg().de()
        })
        .await
        .result?;
    server.sync().await
}
/// Parameters for `list_devices`: the subnet whose devices to list.
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct ListDevicesParams {
    subnet: Ipv4Net,
}
/// Returns the stored configuration of `subnet`, including its client
/// devices, or a not-found error if the subnet does not exist.
pub async fn list_devices(
    ctx: TunnelContext,
    ListDevicesParams { subnet }: ListDevicesParams,
) -> Result<WgSubnetConfig, Error> {
    let peek = ctx.db.peek().await;
    let subnet_model = peek
        .as_wg()
        .as_subnets()
        .as_idx(&subnet)
        .or_not_found(&subnet)?;
    subnet_model.de()
}
/// Parameters for `show_config`. `wan_addr` overrides the advertised
/// endpoint address; `local_addr` is injected via the
/// `__ConnectInfo_local_addr` field (the address the client connected to)
/// and is hidden from the CLI.
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct ShowConfigParams {
    subnet: Ipv4Net,
    ip: Ipv4Addr,
    wan_addr: Option<IpAddr>,
    #[serde(rename = "__ConnectInfo_local_addr")]
    #[arg(skip)]
    local_addr: Option<SocketAddr>,
}

/// Renders the WireGuard client configuration (INI text) for the device
/// at `ip` in `subnet`.
///
/// Endpoint address selection, in order of preference:
/// 1. an explicit `wan_addr`, or the address the caller connected to,
///    provided it is not loopback/private/link-local;
/// 2. the web server's configured listen address, if any;
/// 3. the first advertised subnet address of a public network interface.
pub async fn show_config(
    ctx: TunnelContext,
    ShowConfigParams {
        subnet,
        ip,
        wan_addr,
        local_addr,
    }: ShowConfigParams,
) -> Result<String, Error> {
    let peek = ctx.db.peek().await;
    let wg = peek.as_wg();
    let client = wg
        .as_subnets()
        .as_idx(&subnet)
        .or_not_found(&subnet)?
        .as_clients()
        .as_idx(&ip)
        .or_not_found(&ip)?
        .de()?;
    // Only advertise an address that is plausibly reachable from the WAN.
    let wan_addr = if let Some(wan_addr) = wan_addr.or(local_addr.map(|a| a.ip())).filter(|ip| {
        !ip.is_loopback()
            && !match ip {
                IpAddr::V4(ipv4) => ipv4.is_private() || ipv4.is_link_local(),
                IpAddr::V6(ipv6) => ipv6.is_unique_local() || ipv6.is_unicast_link_local(),
            }
    }) {
        wan_addr
    } else if let Some(webserver) = peek.as_webserver().as_listen().de()? {
        webserver.ip()
    } else {
        // Fall back to the first subnet of any public network interface.
        ctx.net_iface
            .peek(|i| {
                i.iter().find_map(|(_, info)| {
                    info.ip_info
                        .as_ref()
                        .filter(|_| info.public())
                        .iter()
                        .find_map(|info| info.subnets.iter().next())
                        .copied()
                })
            })
            .or_not_found("a public IP address")?
            .addr()
    };
    Ok(client
        .client_config(
            ip,
            wg.as_key().de()?.verifying_key(),
            (wan_addr, wg.as_port().de()?).into(),
        )
        .to_string())
}
/// Parameters for `add_forward`: forward traffic arriving at `source`
/// to `target`.
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct AddPortForwardParams {
    source: SocketAddrV4,
    target: SocketAddrV4,
}

/// Installs a live port forward and persists it so it is restored on
/// restart.
///
/// NOTE(review): the live rule and the in-memory guard are installed
/// before the db write; if the db write fails, the forward stays active
/// but unpersisted — confirm this inconsistency is acceptable.
pub async fn add_forward(
    ctx: TunnelContext,
    AddPortForwardParams { source, target }: AddPortForwardParams,
) -> Result<(), Error> {
    // `rc` is a refcount guard that keeps the live forward installed.
    let rc = ctx.forward.add_forward(source, target).await?;
    ctx.active_forwards.mutate(|m| {
        m.insert(source, rc);
    });
    ctx.db
        .mutate(|db| db.as_port_forwards_mut().insert(&source, &target))
        .await
        .result?;
    Ok(())
}
/// Parameters for `remove_forward`: the listening address of the forward.
#[derive(Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct RemovePortForwardParams {
    source: SocketAddrV4,
}

/// Deletes a persisted port forward, drops its live refcount guard, and
/// garbage-collects forwarding rules if one was active.
pub async fn remove_forward(
    ctx: TunnelContext,
    RemovePortForwardParams { source, .. }: RemovePortForwardParams,
) -> Result<(), Error> {
    ctx.db
        .mutate(|db| db.as_port_forwards_mut().remove(&source))
        .await
        .result?;
    if let Some(rc) = ctx.active_forwards.mutate(|m| m.remove(&source)) {
        // Dropping the guard releases this forward; gc() then removes any
        // rules with no remaining holders.
        drop(rc);
        ctx.forward.gc().await?;
    }
    Ok(())
}

View File

@@ -0,0 +1,323 @@
use clap::Parser;
use imbl::HashMap;
use imbl_value::InternedString;
use itertools::Itertools;
use patch_db::HasModel;
use rpc_toolkit::{Context, HandlerArgs, HandlerExt, ParentHandler, from_fn_async};
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use crate::auth::{Sessions, check_password};
use crate::context::CliContext;
use crate::middleware::auth::AuthContext;
use crate::middleware::signature::SignatureAuthContext;
use crate::prelude::*;
use crate::rpc_continuations::OpenAuthedContinuations;
use crate::sign::AnyVerifyingKey;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::db::TunnelDatabase;
use crate::util::serde::{HandlerExtSerde, display_serializable};
use crate::util::sync::SyncMutex;
/// Signature-based auth for the tunnel daemon: requests are bound to the
/// daemon's advertised addresses/names and verified against the set of
/// authorized public keys stored in the db.
impl SignatureAuthContext for TunnelContext {
    type Database = TunnelDatabase;
    type AdditionalMetadata = ();
    type CheckPubkeyRes = ();
    fn db(&self) -> &TypedPatchDb<Self::Database> {
        &self.db
    }
    /// Strings a request signature may be bound to: the webserver's listen
    /// address (if configured) plus every DNS name / IP address from the
    /// TLS certificate's subject alternative names.
    async fn sig_context(
        &self,
    ) -> impl IntoIterator<Item = Result<impl AsRef<str> + Send, Error>> + Send {
        let peek = self.db().peek().await;
        peek.as_webserver()
            .as_listen()
            .de()
            .map(|a| a.as_ref().map(InternedString::from_display))
            .transpose()
            .into_iter()
            .chain(
                std::iter::once_with(move || {
                    peek.as_webserver()
                        .as_certificate()
                        .de()
                        .ok()
                        .flatten()
                        .and_then(|cert_data| cert_data.cert.0.first().cloned())
                        .and_then(|cert| cert.subject_alt_names())
                        .into_iter()
                        .flatten()
                        .filter_map(|san| {
                            // Prefer DNS-name SAN entries; otherwise decode
                            // raw 4- or 16-byte IP entries into text form.
                            san.dnsname().map(InternedString::from).or_else(|| {
                                san.ipaddress().and_then(|ip_bytes| {
                                    let ip: std::net::IpAddr = match ip_bytes.len() {
                                        4 => std::net::IpAddr::V4(std::net::Ipv4Addr::from(
                                            <[u8; 4]>::try_from(ip_bytes).ok()?,
                                        )),
                                        16 => std::net::IpAddr::V6(std::net::Ipv6Addr::from(
                                            <[u8; 16]>::try_from(ip_bytes).ok()?,
                                        )),
                                        _ => return None,
                                    };
                                    Some(InternedString::from_display(&ip))
                                })
                            })
                        })
                        .map(Ok)
                        .collect::<Vec<_>>()
                })
                .flatten(),
            )
    }
    /// Accepts a request only if its signing key is present in the db's
    /// authorized-pubkeys map.
    fn check_pubkey(
        db: &Model<Self::Database>,
        pubkey: Option<&crate::sign::AnyVerifyingKey>,
        _: Self::AdditionalMetadata,
    ) -> Result<Self::CheckPubkeyRes, Error> {
        if let Some(pubkey) = pubkey {
            if db.as_auth_pubkeys().de()?.contains_key(pubkey) {
                return Ok(());
            }
        }
        Err(Error::new(
            eyre!("Key is not authorized"),
            ErrorKind::IncorrectPassword,
        ))
    }
    /// No additional work is required after successful signature auth.
    async fn post_auth_hook(
        &self,
        _: Self::CheckPubkeyRes,
        _: &rpc_toolkit::RpcRequest,
    ) -> Result<(), Error> {
        Ok(())
    }
}
/// Session/password auth for the tunnel daemon. Local CLI access is
/// authorized via a root-owned cookie file under /run/start-tunnel.
impl AuthContext for TunnelContext {
    const LOCAL_AUTH_COOKIE_PATH: &str = "/run/start-tunnel/rpc.authcookie";
    const LOCAL_AUTH_COOKIE_OWNERSHIP: &str = "root:root";
    fn access_sessions(db: &mut Model<Self::Database>) -> &mut Model<crate::auth::Sessions> {
        db.as_sessions_mut()
    }
    fn ephemeral_sessions(&self) -> &SyncMutex<Sessions> {
        &self.ephemeral_sessions
    }
    fn open_authed_continuations(&self) -> &OpenAuthedContinuations<Option<InternedString>> {
        &self.open_authed_continuations
    }
    /// Verifies `password` against the stored hash; an unset password is
    /// replaced by the default (empty) hash value before checking.
    fn check_password(db: &Model<Self::Database>, password: &str) -> Result<(), Error> {
        check_password(&db.as_password().de()?.unwrap_or_default(), password)
    }
}
/// Human-readable metadata stored alongside an authorized public key.
#[derive(Clone, Debug, Deserialize, Serialize, HasModel, TS, Parser)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
#[ts(export)]
pub struct SignerInfo {
    // Display name identifying the key's owner/device.
    pub name: InternedString,
}
/// Builds the `auth` handler tree: login/logout, password management, and
/// authorized-key management for the tunnel daemon.
pub fn auth_api<C: Context>() -> ParentHandler<C> {
    ParentHandler::new()
        .subcommand(
            "login",
            from_fn_async(crate::auth::login_impl::<TunnelContext>)
                .with_metadata("login", Value::Bool(true))
                .no_cli(),
        )
        .subcommand(
            "logout",
            from_fn_async(crate::auth::logout::<TunnelContext>)
                .with_metadata("get_session", Value::Bool(true))
                .no_display()
                .with_about("Log out of current auth session")
                .with_call_remote::<CliContext>(),
        )
        // "set-password" is registered twice: the first entry is the
        // server-side RPC handler (no CLI); the second is the CLI-side
        // prompt that calls it remotely.
        // NOTE(review): confirm rpc_toolkit dispatches duplicate subcommand
        // names this way rather than shadowing one of them.
        .subcommand("set-password", from_fn_async(set_password_rpc).no_cli())
        .subcommand(
            "set-password",
            from_fn_async(set_password_cli)
                .with_about("Set user interface password")
                .no_display(),
        )
        .subcommand(
            "reset-password",
            from_fn_async(reset_password)
                .with_about("Reset user interface password")
                .no_display(),
        )
        .subcommand(
            "key",
            ParentHandler::<C>::new()
                .subcommand(
                    "add",
                    from_fn_async(add_key)
                        .with_metadata("sync_db", Value::Bool(true))
                        .no_display()
                        .with_about("Add a new authorized key")
                        .with_call_remote::<CliContext>(),
                )
                .subcommand(
                    "remove",
                    from_fn_async(remove_key)
                        .with_metadata("sync_db", Value::Bool(true))
                        .no_display()
                        .with_about("Remove an authorized key")
                        .with_call_remote::<CliContext>(),
                )
                .subcommand(
                    "list",
                    from_fn_async(list_keys)
                        .with_metadata("sync_db", Value::Bool(true))
                        .with_display_serializable()
                        .with_custom_display_fn(|HandlerArgs { params, .. }, res| {
                            use prettytable::*;
                            // Honor --format if given; otherwise print a
                            // table of key names and public keys.
                            if let Some(format) = params.format {
                                return display_serializable(format, res);
                            }
                            let mut table = Table::new();
                            table.add_row(row![bc => "NAME", "KEY"]);
                            for (key, info) in res {
                                table.add_row(row![info.name, key]);
                            }
                            table.print_tty(false)?;
                            Ok(())
                        })
                        .with_about("List authorized keys")
                        .with_call_remote::<CliContext>(),
                ),
        )
}
/// Parameters for `add_key`: a display name and the public key to authorize.
#[derive(Debug, Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct AddKeyParams {
    pub name: InternedString,
    pub key: AnyVerifyingKey,
}
/// Authorizes `key`, recording it under `name` in the db's pubkey map.
/// Re-adding an existing key overwrites its metadata.
pub async fn add_key(
    ctx: TunnelContext,
    AddKeyParams { name, key }: AddKeyParams,
) -> Result<(), Error> {
    let info = SignerInfo { name };
    ctx.db
        .mutate(|db| {
            let pubkeys = db.as_auth_pubkeys_mut();
            pubkeys.mutate(|keys| {
                keys.insert(key, info);
                Ok(())
            })
        })
        .await
        .result
}
/// Parameters for `remove_key`: the public key to deauthorize.
#[derive(Debug, Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct RemoveKeyParams {
    pub key: AnyVerifyingKey,
}
/// Deauthorizes `key` by removing it from the db's pubkey map.
/// Removing a key that is not present is not an error.
pub async fn remove_key(
    ctx: TunnelContext,
    RemoveKeyParams { key }: RemoveKeyParams,
) -> Result<(), Error> {
    ctx.db
        .mutate(|db| {
            let pubkeys = db.as_auth_pubkeys_mut();
            pubkeys.mutate(|keys| {
                keys.remove(&key);
                Ok(())
            })
        })
        .await
        .result?;
    Ok(())
}
/// Returns every authorized public key and its associated metadata.
pub async fn list_keys(ctx: TunnelContext) -> Result<HashMap<AnyVerifyingKey, SignerInfo>, Error> {
    ctx.db.peek().await.into_auth_pubkeys().de()
}
/// Request body for `set_password_rpc`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SetPasswordParams {
    pub password: String,
}

/// Hashes the new password with argon2 (RFC 9106 low-memory parameters,
/// random 16-byte salt) and stores the encoded hash in the db.
pub async fn set_password_rpc(
    ctx: TunnelContext,
    SetPasswordParams { password }: SetPasswordParams,
) -> Result<(), Error> {
    let pwhash = argon2::hash_encoded(
        password.as_bytes(),
        &rand::random::<[u8; 16]>(),
        &argon2::Config::rfc9106_low_mem(),
    )
    .with_kind(ErrorKind::PasswordHashGeneration)?;
    ctx.db
        .mutate(|db| db.as_password_mut().ser(&Some(pwhash)))
        .await
        .result?;
    Ok(())
}
/// CLI front-end for `set-password`: prompts twice with hidden input,
/// verifies the entries match, then invokes the server-side handler
/// over RPC.
pub async fn set_password_cli(
    HandlerArgs {
        context,
        parent_method,
        method,
        ..
    }: HandlerArgs<CliContext>,
) -> Result<(), Error> {
    let password = rpassword::prompt_password("New Password: ")?;
    let confirm = rpassword::prompt_password("Confirm Password: ")?;
    if password != confirm {
        return Err(Error::new(
            eyre!("Passwords do not match"),
            ErrorKind::InvalidRequest,
        ));
    }
    // Re-join the method path (e.g. "auth.set-password") for the remote call.
    context
        .call_remote::<TunnelContext>(
            &parent_method.iter().chain(method.iter()).join("."),
            to_value(&SetPasswordParams { password })?,
        )
        .await?;
    println!("Password set successfully");
    Ok(())
}
/// Generates a random 16-byte password (base32, lowercase, unpadded),
/// sets it via the remote `set-password` handler, and prints it so the
/// operator can log in again.
pub async fn reset_password(
    HandlerArgs {
        context,
        parent_method,
        method,
        ..
    }: HandlerArgs<CliContext>,
) -> Result<(), Error> {
    println!("Generating a random password...");
    let params = SetPasswordParams {
        password: base32::encode(
            base32::Alphabet::Rfc4648Lower { padding: false },
            &rand::random::<[u8; 16]>(),
        ),
    };
    context
        .call_remote::<TunnelContext>(
            &parent_method.iter().chain(method.iter()).join("."),
            to_value(&params)?,
        )
        .await?;
    println!("Your new password is:");
    println!("{}", params.password);
    Ok(())
}

View File

@@ -1,3 +1,5 @@
# StartTunnel config for {name}
[Interface]
Address = {addr}/24
PrivateKey = {privkey}
@@ -5,6 +7,6 @@ PrivateKey = {privkey}
[Peer]
PublicKey = {server_pubkey}
PresharedKey = {psk}
AllowedIPs = 0.0.0.0/0, ::/0
AllowedIPs = 0.0.0.0/0,::/0
Endpoint = {server_addr}
PersistentKeepalive = 25

View File

@@ -1,31 +1,44 @@
use std::collections::BTreeSet;
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::collections::BTreeMap;
use std::net::{IpAddr, SocketAddr, SocketAddrV4};
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use clap::Parser;
use cookie::{Cookie, Expiration, SameSite};
use http::HeaderMap;
use imbl::OrdMap;
use imbl_value::InternedString;
use include_dir::Dir;
use models::GatewayId;
use patch_db::PatchDb;
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{CallRemote, Context, Empty};
use rpc_toolkit::{CallRemote, Context, Empty, ParentHandler};
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use tokio::sync::broadcast::Sender;
use tracing::instrument;
use url::Url;
use crate::auth::{Sessions, check_password};
use crate::context::CliContext;
use crate::auth::Sessions;
use crate::context::config::ContextConfig;
use crate::middleware::auth::AuthContext;
use crate::middleware::signature::SignatureAuthContext;
use crate::context::{CliContext, RpcContext};
use crate::db::model::public::{NetworkInterfaceInfo, NetworkInterfaceType};
use crate::else_empty_dir;
use crate::middleware::auth::{Auth, AuthContext};
use crate::middleware::cors::Cors;
use crate::net::forward::PortForwardController;
use crate::net::gateway::NetworkInterfaceWatcher;
use crate::net::static_server::UiContext;
use crate::prelude::*;
use crate::rpc_continuations::{OpenAuthedContinuations, RpcContinuations};
use crate::tunnel::TUNNEL_DEFAULT_PORT;
use crate::tunnel::TUNNEL_DEFAULT_LISTEN;
use crate::tunnel::api::tunnel_api;
use crate::tunnel::db::TunnelDatabase;
use crate::util::sync::SyncMutex;
use crate::tunnel::wg::WIREGUARD_INTERFACE_NAME;
use crate::util::Invoke;
use crate::util::collections::OrdMapIterMut;
use crate::util::io::read_file_to_string;
use crate::util::sync::{SyncMutex, Watch};
#[derive(Debug, Clone, Default, Deserialize, Serialize, Parser)]
#[serde(rename_all = "kebab-case")]
@@ -59,14 +72,14 @@ impl TunnelConfig {
pub struct TunnelContextSeed {
pub listen: SocketAddr,
pub addrs: BTreeSet<IpAddr>,
pub db: TypedPatchDb<TunnelDatabase>,
pub datadir: PathBuf,
pub rpc_continuations: RpcContinuations,
pub open_authed_continuations: OpenAuthedContinuations<Option<InternedString>>,
pub ephemeral_sessions: SyncMutex<Sessions>,
pub net_iface: NetworkInterfaceWatcher,
pub net_iface: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
pub forward: PortForwardController,
pub active_forwards: SyncMutex<BTreeMap<SocketAddrV4, Arc<()>>>,
pub shutdown: Sender<()>,
}
@@ -75,6 +88,7 @@ pub struct TunnelContext(Arc<TunnelContextSeed>);
impl TunnelContext {
#[instrument(skip_all)]
pub async fn init(config: &TunnelConfig) -> Result<Self, Error> {
Self::init_auth_cookie().await?;
let (shutdown, _) = tokio::sync::broadcast::channel(1);
let datadir = config
.datadir
@@ -90,19 +104,83 @@ impl TunnelContext {
|| async { Ok(Default::default()) },
)
.await?;
let listen = config.tunnel_listen.unwrap_or(SocketAddr::new(
Ipv6Addr::UNSPECIFIED.into(),
TUNNEL_DEFAULT_PORT,
));
let net_iface = NetworkInterfaceWatcher::new(async { OrdMap::new() }, []);
let forward = PortForwardController::new(net_iface.subscribe());
let listen = config.tunnel_listen.unwrap_or(TUNNEL_DEFAULT_LISTEN);
let ip_info = crate::net::utils::load_ip_info().await?;
let net_iface = db
.mutate(|db| {
db.as_gateways_mut().mutate(|g| {
for (_, v) in OrdMapIterMut::from(&mut *g) {
v.ip_info = None;
}
for (id, info) in ip_info {
if id.as_str() != WIREGUARD_INTERFACE_NAME {
g.entry(id).or_default().ip_info = Some(Arc::new(info));
}
}
Ok(g.clone())
})
})
.await
.result?;
let net_iface = Watch::new(net_iface);
let forward = PortForwardController::new();
Command::new("sysctl")
.arg("-w")
.arg("net.ipv4.ip_forward=1")
.invoke(ErrorKind::Network)
.await?;
for iface in net_iface.peek(|i| {
i.iter()
.filter(|(_, info)| {
info.ip_info.as_ref().map_or(false, |i| {
i.device_type != Some(NetworkInterfaceType::Loopback)
})
})
.map(|(name, _)| name)
.filter(|id| id.as_str() != WIREGUARD_INTERFACE_NAME)
.cloned()
.collect::<Vec<_>>()
}) {
if Command::new("iptables")
.arg("-t")
.arg("nat")
.arg("-C")
.arg("POSTROUTING")
.arg("-o")
.arg(iface.as_str())
.arg("-j")
.arg("MASQUERADE")
.invoke(ErrorKind::Network)
.await
.is_err()
{
tracing::info!("Adding masquerade rule for interface {}", iface);
Command::new("iptables")
.arg("-t")
.arg("nat")
.arg("-A")
.arg("POSTROUTING")
.arg("-o")
.arg(iface.as_str())
.arg("-j")
.arg("MASQUERADE")
.invoke(ErrorKind::Network)
.await
.log_err();
}
}
let peek = db.peek().await;
peek.as_wg().de()?.sync().await?;
let mut active_forwards = BTreeMap::new();
for (from, to) in peek.as_port_forwards().de()?.0 {
active_forwards.insert(from, forward.add_forward(from, to).await?);
}
Ok(Self(Arc::new(TunnelContextSeed {
listen,
addrs: crate::net::utils::all_socket_addrs_for(listen.port())
.await?
.into_iter()
.map(|(_, a)| a.ip())
.collect(),
db,
datadir,
rpc_continuations: RpcContinuations::new(),
@@ -110,6 +188,7 @@ impl TunnelContext {
ephemeral_sessions: SyncMutex::new(Sessions::new()),
net_iface,
forward,
active_forwards: SyncMutex::new(active_forwards),
shutdown,
})))
}
@@ -120,6 +199,12 @@ impl AsRef<RpcContinuations> for TunnelContext {
}
}
impl AsRef<OpenAuthedContinuations<Option<InternedString>>> for TunnelContext {
fn as_ref(&self) -> &OpenAuthedContinuations<Option<InternedString>> {
&self.open_authed_continuations
}
}
impl Context for TunnelContext {}
impl Deref for TunnelContext {
type Target = TunnelContextSeed;
@@ -133,66 +218,6 @@ pub struct TunnelAddrParams {
pub tunnel: IpAddr,
}
impl SignatureAuthContext for TunnelContext {
    type Database = TunnelDatabase;
    type AdditionalMetadata = ();
    type CheckPubkeyRes = ();
    fn db(&self) -> &TypedPatchDb<Self::Database> {
        &self.db
    }
    /// Strings bound into request signatures: every address the server
    /// listens on, excluding loopback and unspecified addresses.
    async fn sig_context(
        &self,
    ) -> impl IntoIterator<Item = Result<impl AsRef<str> + Send, Error>> + Send {
        self.addrs
            .iter()
            .filter(|a| !match a {
                IpAddr::V4(a) => a.is_loopback() || a.is_unspecified(),
                IpAddr::V6(a) => a.is_loopback() || a.is_unspecified(),
            })
            .map(|a| InternedString::from_display(&a))
            .map(Ok)
    }
    /// A developer key is authorized iff it is present in `auth_pubkeys`.
    fn check_pubkey(
        db: &Model<Self::Database>,
        pubkey: Option<&crate::sign::AnyVerifyingKey>,
        _: Self::AdditionalMetadata,
    ) -> Result<Self::CheckPubkeyRes, Error> {
        if let Some(pubkey) = pubkey {
            if db.as_auth_pubkeys().de()?.contains(pubkey) {
                return Ok(());
            }
        }
        Err(Error::new(
            eyre!("Developer Key is not authorized"),
            ErrorKind::IncorrectPassword,
        ))
    }
    // No additional work needed after a successful signature check.
    async fn post_auth_hook(
        &self,
        _: Self::CheckPubkeyRes,
        _: &rpc_toolkit::RpcRequest,
    ) -> Result<(), Error> {
        Ok(())
    }
}
impl AuthContext for TunnelContext {
    // Local auth cookie for start-tunnel's own RPC socket, owned by root:root.
    const LOCAL_AUTH_COOKIE_PATH: &str = "/run/start-tunnel/rpc.authcookie";
    const LOCAL_AUTH_COOKIE_OWNERSHIP: &str = "root:root";
    fn access_sessions(db: &mut Model<Self::Database>) -> &mut Model<crate::auth::Sessions> {
        db.as_sessions_mut()
    }
    fn ephemeral_sessions(&self) -> &SyncMutex<Sessions> {
        &self.ephemeral_sessions
    }
    fn open_authed_continuations(&self) -> &OpenAuthedContinuations<Option<InternedString>> {
        &self.open_authed_continuations
    }
    fn check_password(db: &Model<Self::Database>, password: &str) -> Result<(), Error> {
        check_password(&db.as_password().de()?, password)
    }
}
impl CallRemote<TunnelContext> for CliContext {
async fn call_remote(
&self,
@@ -200,25 +225,97 @@ impl CallRemote<TunnelContext> for CliContext {
params: Value,
_: Empty,
) -> Result<Value, RpcError> {
let tunnel_addr = if let Some(addr) = self.tunnel_addr {
addr
let (tunnel_addr, addr_from_config) = if let Some(addr) = self.tunnel_addr {
(addr, true)
} else if let Some(addr) = self.tunnel_listen {
addr
(addr, true)
} else {
(TUNNEL_DEFAULT_LISTEN, false)
};
let local =
if let Ok(local) = read_file_to_string(TunnelContext::LOCAL_AUTH_COOKIE_PATH).await {
self.cookie_store
.lock()
.unwrap()
.insert_raw(
&Cookie::build(("local", local))
.domain(&tunnel_addr.ip().to_string())
.expires(Expiration::Session)
.same_site(SameSite::Strict)
.build(),
&format!("http://{tunnel_addr}").parse()?,
)
.with_kind(crate::ErrorKind::Network)?;
true
} else {
false
};
let (url, sig_ctx) = if local && tunnel_addr.ip().is_loopback() {
(format!("http://{tunnel_addr}/rpc/v0").parse()?, None)
} else if addr_from_config {
(
format!("https://{tunnel_addr}/rpc/v0").parse()?,
Some(InternedString::from_display(&tunnel_addr.ip())),
)
} else {
return Err(Error::new(eyre!("`--tunnel` required"), ErrorKind::InvalidRequest).into());
};
let sig_addr = self.tunnel_listen.unwrap_or(tunnel_addr);
let url = format!("https://{tunnel_addr}").parse()?;
method = method.strip_prefix("tunnel.").unwrap_or(method);
crate::middleware::signature::call_remote(
self,
url,
&InternedString::from_display(&sig_addr.ip()),
HeaderMap::new(),
sig_ctx.as_deref(),
method,
params,
)
.await
}
}
/// Parameter carrying the base URL of a remote start-tunnel instance.
#[derive(Debug, Deserialize, Serialize, Parser)]
pub struct TunnelUrlParams {
    pub tunnel: Url,
}
impl CallRemote<TunnelContext, TunnelUrlParams> for RpcContext {
    /// Forward an RPC call from the main server to the tunnel instance at
    /// `tunnel`. The "tunnel." method prefix is stripped before forwarding,
    /// and the URL's host (if any) is used as the signature context.
    async fn call_remote(
        &self,
        mut method: &str,
        params: Value,
        TunnelUrlParams { tunnel }: TunnelUrlParams,
    ) -> Result<Value, RpcError> {
        let url = tunnel.join("rpc/v0")?;
        method = method.strip_prefix("tunnel.").unwrap_or(method);
        let sig_ctx = url.host_str().map(InternedString::from_display);
        crate::middleware::signature::call_remote(
            self,
            url,
            HeaderMap::new(),
            sig_ctx.as_deref(),
            method,
            params,
        )
        .await
    }
}
impl UiContext for TunnelContext {
    // Static assets for the start-tunnel web UI, embedded at compile time when
    // the "tunnel" feature is enabled (an empty dir otherwise).
    const UI_DIR: &'static include_dir::Dir<'static> = &else_empty_dir!(
        feature = "tunnel" =>
        include_dir::include_dir!("$CARGO_MANIFEST_DIR/../../web/dist/static/start-tunnel")
    );
    fn api() -> ParentHandler<Self> {
        tracing::info!("loading tunnel api...");
        tunnel_api()
    }
    fn middleware(server: rpc_toolkit::Server<Self>) -> rpc_toolkit::HttpServer<Self> {
        server.middleware(Cors::new()).middleware(Auth::new())
    }
}

View File

@@ -1,11 +1,14 @@
use std::collections::{BTreeMap, HashSet};
use std::net::{Ipv4Addr, SocketAddrV4};
use std::collections::BTreeMap;
use std::net::SocketAddrV4;
use std::path::PathBuf;
use std::time::Duration;
use axum::extract::ws;
use clap::Parser;
use imbl::{HashMap, OrdMap};
use imbl_value::InternedString;
use ipnet::Ipv4Net;
use itertools::Itertools;
use models::GatewayId;
use patch_db::Dump;
use patch_db::json_ptr::{JsonPointer, ROOT};
use rpc_toolkit::yajrc::RpcError;
@@ -16,21 +19,48 @@ use ts_rs::TS;
use crate::auth::Sessions;
use crate::context::CliContext;
use crate::db::model::public::NetworkInterfaceInfo;
use crate::prelude::*;
use crate::rpc_continuations::{Guid, RpcContinuation};
use crate::sign::AnyVerifyingKey;
use crate::tunnel::auth::SignerInfo;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::web::WebserverInfo;
use crate::tunnel::wg::WgServer;
use crate::util::net::WebSocketExt;
use crate::util::serde::{HandlerExtSerde, apply_expr};
#[derive(Default, Deserialize, Serialize, HasModel)]
#[derive(Default, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
pub struct TunnelDatabase {
pub webserver: WebserverInfo,
pub sessions: Sessions,
pub password: String,
pub auth_pubkeys: HashSet<AnyVerifyingKey>,
pub password: Option<String>,
#[ts(as = "std::collections::HashMap::<AnyVerifyingKey, SignerInfo>")]
pub auth_pubkeys: HashMap<AnyVerifyingKey, SignerInfo>,
#[ts(as = "std::collections::BTreeMap::<AnyVerifyingKey, SignerInfo>")]
pub gateways: OrdMap<GatewayId, NetworkInterfaceInfo>,
pub wg: WgServer,
pub port_forwards: BTreeMap<SocketAddrV4, SocketAddrV4>,
pub port_forwards: PortForwards,
}
#[test]
fn export_bindings_tunnel_db() {
TunnelDatabase::export_all_to("bindings/tunnel").unwrap();
}
/// Map of inbound listen address -> forward target for the tunnel's
/// port-forwarding rules.
#[derive(Clone, Debug, Default, Deserialize, Serialize, TS)]
pub struct PortForwards(pub BTreeMap<SocketAddrV4, SocketAddrV4>);
impl Map for PortForwards {
    type Key = SocketAddrV4;
    type Value = SocketAddrV4;
    // Keys are addressed in patch-db by their `Display` form (e.g. "1.2.3.4:80").
    fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
        Self::key_string(key)
    }
    fn key_string(key: &Self::Key) -> Result<InternedString, Error> {
        Ok(InternedString::from_display(key))
    }
}
pub fn db_api<C: Context>() -> ParentHandler<C> {
@@ -47,6 +77,12 @@ pub fn db_api<C: Context>() -> ParentHandler<C> {
.with_metadata("admin", Value::Bool(true))
.no_cli(),
)
.subcommand(
"subscribe",
from_fn_async(subscribe)
.with_metadata("get_session", Value::Bool(true))
.no_cli(),
)
.subcommand(
"apply",
from_fn_async(cli_apply)
@@ -195,3 +231,75 @@ pub async fn apply(ctx: TunnelContext, ApplyParams { expr, .. }: ApplyParams) ->
.await
.result
}
/// Parameters for `db.subscribe`.
#[derive(Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct SubscribeParams {
    // Subtree to watch; defaults to the database root.
    #[ts(type = "string | null")]
    pointer: Option<JsonPointer>,
    // Injected by the auth middleware via the "get_session" metadata on the
    // handler; not exposed in the generated TS bindings.
    #[ts(skip)]
    #[serde(rename = "__Auth_session")]
    session: Option<InternedString>,
}
/// Initial state dump plus the continuation GUID for streaming revisions.
#[derive(Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct SubscribeRes {
    #[ts(type = "{ id: number; value: unknown }")]
    pub dump: Dump,
    pub guid: Guid,
}
/// Subscribe to database changes.
///
/// Returns the current state (`dump`) and registers an authenticated
/// websocket continuation (registered with a 30-second duration —
/// NOTE(review): presumably the window to claim it; confirm in
/// `RpcContinuation`) that streams each subsequent revision as a JSON text
/// frame, closing normally when the revision stream ends.
pub async fn subscribe(
    ctx: TunnelContext,
    SubscribeParams { pointer, session }: SubscribeParams,
) -> Result<SubscribeRes, Error> {
    let (dump, mut sub) = ctx
        .db
        .dump_and_sub(pointer.unwrap_or_else(|| ROOT.to_owned()))
        .await;
    let guid = Guid::new();
    ctx.rpc_continuations
        .add(
            guid.clone(),
            RpcContinuation::ws_authed(
                &ctx,
                session,
                |mut ws| async move {
                    if let Err(e) = async {
                        loop {
                            tokio::select! {
                                rev = sub.recv() => {
                                    if let Some(rev) = rev {
                                        ws.send(ws::Message::Text(
                                            serde_json::to_string(&rev)
                                                .with_kind(ErrorKind::Serialization)?
                                                .into(),
                                        ))
                                        .await
                                        .with_kind(ErrorKind::Network)?;
                                    } else {
                                        // Revision stream ended: close cleanly.
                                        return ws.normal_close("complete").await;
                                    }
                                }
                                msg = ws.recv() => {
                                    // Client closed the socket: stop streaming.
                                    if msg.transpose().with_kind(ErrorKind::Network)?.is_none() {
                                        return Ok(())
                                    }
                                }
                            }
                        }
                    }
                    .await
                    {
                        tracing::error!("Error in db websocket: {e}");
                        tracing::debug!("{e:?}");
                    }
                },
                Duration::from_secs(30),
            ),
        )
        .await;
    Ok(SubscribeRes { dump, guid })
}

View File

@@ -1 +0,0 @@
use crate::prelude::*;

View File

@@ -1,82 +1,23 @@
use axum::Router;
use futures::future::ready;
use rpc_toolkit::{Context, HandlerExt, ParentHandler, Server, from_fn_async};
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use crate::context::CliContext;
use crate::middleware::auth::Auth;
use crate::middleware::cors::Cors;
use crate::net::static_server::{bad_request, not_found, server_error};
use crate::net::web_server::{Accept, WebServer};
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use axum::Router;
use crate::net::static_server::ui_router;
use crate::tunnel::context::TunnelContext;
pub mod api;
pub mod auth;
pub mod context;
pub mod db;
pub mod forward;
pub mod web;
pub mod wg;
pub const TUNNEL_DEFAULT_PORT: u16 = 5960;
pub const TUNNEL_DEFAULT_LISTEN: SocketAddr = SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::new(127, 0, 59, 60),
TUNNEL_DEFAULT_PORT,
));
pub fn tunnel_router(ctx: TunnelContext) -> Router {
use axum::extract as x;
use axum::routing::{any, get};
Router::new()
.route("/rpc/{*path}", {
let ctx = ctx.clone();
any(
Server::new(move || ready(Ok(ctx.clone())), api::tunnel_api())
.middleware(Cors::new())
.middleware(Auth::new())
)
})
.route(
"/ws/rpc/{*path}",
get({
let ctx = ctx.clone();
move |x::Path(path): x::Path<String>,
ws: axum::extract::ws::WebSocketUpgrade| async move {
match Guid::from(&path) {
None => {
tracing::debug!("No Guid Path");
bad_request()
}
Some(guid) => match ctx.rpc_continuations.get_ws_handler(&guid).await {
Some(cont) => ws.on_upgrade(cont),
_ => not_found(),
},
}
}
}),
)
.route(
"/rest/rpc/{*path}",
any({
let ctx = ctx.clone();
move |request: x::Request| async move {
let path = request
.uri()
.path()
.strip_prefix("/rest/rpc/")
.unwrap_or_default();
match Guid::from(&path) {
None => {
tracing::debug!("No Guid Path");
bad_request()
}
Some(guid) => match ctx.rpc_continuations.get_rest_handler(&guid).await {
None => not_found(),
Some(cont) => cont(request).await.unwrap_or_else(server_error),
},
}
}
}),
)
}
impl<A: Accept + Send + Sync + 'static> WebServer<A> {
pub fn serve_tunnel(&mut self, ctx: TunnelContext) {
self.serve_router(tunnel_router(ctx))
}
ui_router(ctx)
}

View File

@@ -0,0 +1,688 @@
use std::collections::VecDeque;
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use clap::Parser;
use hickory_client::proto::rr::rdata::cert;
use imbl_value::{InternedString, json};
use itertools::Itertools;
use openssl::pkey::{PKey, Private};
use openssl::x509::X509;
use rpc_toolkit::{
Context, Empty, HandlerArgs, HandlerExt, ParentHandler, from_fn_async, from_fn_async_local,
};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio_rustls::rustls::ServerConfig;
use tokio_rustls::rustls::crypto::CryptoProvider;
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use tokio_rustls::rustls::server::ClientHello;
use ts_rs::TS;
use crate::context::CliContext;
use crate::net::ssl::SANInfo;
use crate::net::tls::TlsHandler;
use crate::net::web_server::Accept;
use crate::prelude::*;
use crate::tunnel::auth::SetPasswordParams;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::db::TunnelDatabase;
use crate::util::serde::{HandlerExtSerde, Pem, display_serializable};
use crate::util::tui::{choose, choose_custom_display, parse_as, prompt, prompt_multiline};
/// Persisted configuration of the tunnel's web UI server.
#[derive(Debug, Default, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
pub struct WebserverInfo {
    // Whether the web UI should be served.
    pub enabled: bool,
    // Socket address to bind; None until configured via `web.set-listen`.
    pub listen: Option<SocketAddr>,
    // TLS key and chain; None until imported or generated.
    pub certificate: Option<TunnelCertData>,
}
/// A TLS private key together with its certificate chain (PEM-encoded).
#[derive(Debug, Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct TunnelCertData {
    pub key: Pem<PKey<Private>>,
    pub cert: Pem<Vec<X509>>,
}
/// Resolves the rustls `ServerConfig` for incoming TLS connections from the
/// certificate stored in the tunnel database.
#[derive(Clone)]
pub struct TunnelCertHandler {
    pub db: TypedPatchDb<TunnelDatabase>,
    pub crypto_provider: Arc<CryptoProvider>,
}
impl<'a, A> TlsHandler<'a, A> for TunnelCertHandler
where
    A: Accept + 'a,
    <A as Accept>::Metadata: Send + Sync,
{
    /// Build a `ServerConfig` from the stored certificate.
    ///
    /// Returns None if no certificate is configured or if any conversion step
    /// fails (failures are logged via `log_err` rather than propagated).
    async fn get_config(
        &'a mut self,
        _: &'a ClientHello<'a>,
        _: &'a <A as Accept>::Metadata,
    ) -> Option<ServerConfig> {
        // `de()` yields Result<Option<_>>; `log_err()??` collapses both a
        // failed deserialize (logged) and an absent certificate into None.
        let cert_info = self
            .db
            .peek()
            .await
            .as_webserver()
            .as_certificate()
            .de()
            .log_err()??;
        // Convert each X509 in the chain to DER for rustls.
        let cert_chain: Vec<_> = cert_info
            .cert
            .0
            .iter()
            .map(|c| Ok::<_, Error>(CertificateDer::from(c.to_der()?)))
            .collect::<Result<_, _>>()
            .log_err()?;
        let cert_key = cert_info.key.0.private_key_to_pkcs8().log_err()?;
        Some(
            ServerConfig::builder_with_provider(self.crypto_provider.clone())
                .with_safe_default_protocol_versions()
                .log_err()?
                .with_no_client_auth()
                .with_single_cert(
                    cert_chain,
                    PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert_key)),
                )
                .log_err()?,
        )
    }
}
/// RPC command tree for managing the tunnel's web interface: listen address,
/// TLS certificate, and enable/disable/reset of the server.
pub fn web_api<C: Context>() -> ParentHandler<C> {
    ParentHandler::new()
        .subcommand(
            "init",
            from_fn_async_local(init_web)
                .no_display()
                .with_about("Initialize the webserver"),
        )
        .subcommand(
            "set-listen",
            from_fn_async(set_listen)
                .no_display()
                .with_about("Set the listen address for the webserver")
                .with_call_remote::<CliContext>(),
        )
        .subcommand(
            "get-listen",
            from_fn_async(get_listen)
                .with_display_serializable()
                .with_about("Get the listen address for the webserver")
                .with_call_remote::<CliContext>(),
        )
        .subcommand(
            "get-available-ips",
            from_fn_async(get_available_ips)
                .with_display_serializable()
                .with_about("Get available IP addresses to bind to")
                .with_call_remote::<CliContext>(),
        )
        // "import-certificate" is intentionally registered twice: the first
        // handler serves the remote RPC (no CLI), the second is the local CLI
        // flow that interactively collects the PEM data before calling it.
        .subcommand(
            "import-certificate",
            from_fn_async(import_certificate_rpc).no_cli(),
        )
        .subcommand(
            "import-certificate",
            from_fn_async_local(import_certificate_cli)
                .no_display()
                .with_about("Import a certificate to use for the webserver"),
        )
        .subcommand(
            "generate-certificate",
            from_fn_async(generate_certificate)
                // fixed typo: "certificaet" -> "certificate"
                .with_about("Generate a self signed certificate to use for the webserver")
                .with_call_remote::<CliContext>(),
        )
        .subcommand(
            "get-certificate",
            from_fn_async(get_certificate)
                .with_display_serializable()
                .with_custom_display_fn(|HandlerArgs { params, .. }, res| {
                    // Honor an explicit --format; otherwise print the PEM as-is.
                    if let Some(format) = params.format {
                        return display_serializable(format, res);
                    }
                    if let Some(res) = res {
                        println!("{res}");
                    }
                    Ok(())
                })
                .with_about("Get the certificate for the webserver")
                .with_call_remote::<CliContext>(),
        )
        .subcommand(
            "enable",
            from_fn_async(enable_web)
                .with_about("Enable the webserver")
                .no_display()
                .with_call_remote::<CliContext>(),
        )
        .subcommand(
            "disable",
            from_fn_async(disable_web)
                .no_display()
                .with_about("Disable the webserver")
                .with_call_remote::<CliContext>(),
        )
        .subcommand(
            "reset",
            from_fn_async(reset_web)
                .no_display()
                .with_about("Reset the webserver")
                .with_call_remote::<CliContext>(),
        )
}
/// Persist an imported TLS key/certificate pair as the webserver's
/// certificate.
pub async fn import_certificate_rpc(
    ctx: TunnelContext,
    cert_data: TunnelCertData,
) -> Result<(), Error> {
    let new_cert = Some(cert_data);
    ctx.db
        .mutate(move |db| db.as_webserver_mut().as_certificate_mut().ser(&new_cert))
        .await
        .result?;
    Ok(())
}
/// Interactive CLI flow for importing a TLS private key and certificate chain.
///
/// Prompts the user to paste a PEM private key, then a PEM certificate (or
/// fullchain), validating as it goes:
/// - the first (leaf) certificate's public key must match the private key
/// - each subsequent certificate must have signed the previous one
/// - input ends once a self-signed (root) certificate is reached
/// Finally forwards the validated pair to the server-side
/// `import-certificate` handler.
pub async fn import_certificate_cli(
    HandlerArgs {
        context,
        parent_method,
        method,
        ..
    }: HandlerArgs<CliContext>,
) -> Result<(), Error> {
    let mut key_string = String::new();
    let key: Pem<PKey<Private>> =
        prompt_multiline("Please paste in your PEM encoded private key: ", |line| {
            key_string.push_str(&line);
            key_string.push_str("\n");
            // A PEM block is complete at its "-----END" marker: try parsing.
            if line.trim().starts_with("-----END") {
                return key_string.parse().map(Some).map_err(|e| {
                    // On parse failure, clear the buffer so the user can re-paste.
                    key_string.truncate(0);
                    e
                });
            }
            Ok(None)
        })
        .await?;
    let mut chain = Vec::<X509>::new();
    let mut cert_string = String::new();
    prompt_multiline(
        concat!(
            "Please paste in your PEM encoded certificate",
            " (or certificate chain):"
        ),
        |line| {
            cert_string.push_str(&line);
            cert_string.push_str("\n");
            if line.trim().starts_with("-----END") {
                // One full PEM block accumulated: parse it as a certificate
                // and reset the buffer for the next block in the chain.
                let cert = cert_string.parse::<Pem<X509>>();
                cert_string.truncate(0);
                let cert = cert?;
                let pubkey = cert.0.public_key()?;
                if chain.is_empty() {
                    // Leaf certificate must match the pasted private key.
                    if !key.public_eq(&pubkey) {
                        return Err(Error::new(
                            eyre!("Certificate does not match key!"),
                            ErrorKind::InvalidSignature,
                        ));
                    }
                }
                if let Some(prev) = chain.last() {
                    // Chain order is leaf-first: this cert must have signed
                    // the previous one.
                    if !prev.verify(&pubkey)? {
                        return Err(Error::new(
                            eyre!(concat!(
                                "Invalid Fullchain: ",
                                "Previous cert was not signed by this certificate's key"
                            )),
                            ErrorKind::InvalidSignature,
                        ));
                    }
                }
                // Self-signed => root of the chain; stop prompting.
                let is_root = cert.0.verify(&pubkey)?;
                chain.push(cert.0);
                if is_root { Ok(Some(())) } else { Ok(None) }
            } else {
                Ok(None)
            }
        },
    )
    .await?;
    // Send the validated key + chain to the server-side handler at the same
    // method path this CLI handler is mounted on.
    context
        .call_remote::<TunnelContext>(
            &parent_method.iter().chain(method.iter()).join("."),
            to_value(&TunnelCertData {
                key,
                cert: Pem(chain),
            })?,
        )
        .await?;
    Ok(())
}
/// Parameters for `web.generate-certificate`.
#[derive(Debug, Deserialize, Serialize, Parser)]
pub struct GenerateCertParams {
    // IPs and/or DNS names to include as SANs in the generated certificate.
    #[arg(help = "Subject Alternative Name(s)")]
    pub subject: Vec<InternedString>,
}
/// Generate a self-signed certificate for the given subjects (IPs/domains),
/// store it as the webserver's certificate, and return it.
pub async fn generate_certificate(
    ctx: TunnelContext,
    GenerateCertParams { subject }: GenerateCertParams,
) -> Result<Pem<X509>, Error> {
    // Build the SAN set from the requested subjects, then self-sign.
    let san_info = SANInfo::new(&subject.into_iter().collect());
    let key = crate::net::ssl::generate_key()?;
    let cert = crate::net::ssl::make_self_signed((&key, &san_info))?;
    let cert_data = TunnelCertData {
        key: Pem(key),
        cert: Pem(vec![cert.clone()]),
    };
    ctx.db
        .mutate(move |db| {
            db.as_webserver_mut()
                .as_certificate_mut()
                .ser(&Some(cert_data))
        })
        .await
        .result?;
    Ok(Pem(cert))
}
/// Return the webserver's configured certificate chain, if one is set.
pub async fn get_certificate(ctx: TunnelContext) -> Result<Option<Pem<Vec<X509>>>, Error> {
    let peek = ctx.db.peek().await;
    let cert_info = peek.as_webserver().as_certificate().de()?;
    Ok(cert_info.map(|cert_data| cert_data.cert))
}
/// Parameters for `web.set-listen`.
#[derive(Debug, Deserialize, Serialize, Parser)]
#[serde(rename_all = "camelCase")]
pub struct SetListenParams {
    pub listen: SocketAddr,
}
/// Validate and persist the webserver's listen address.
///
/// Binds a throwaway TcpListener to prove the address is bindable, then drops
/// it before storing the address. NOTE(review): small TOCTOU window — the
/// port could be taken between this check and the actual server bind.
pub async fn set_listen(
    ctx: TunnelContext,
    SetListenParams { listen }: SetListenParams,
) -> Result<(), Error> {
    // Validate that the address is available to bind
    tokio::net::TcpListener::bind(listen)
        .await
        .with_kind(ErrorKind::Network)
        .with_ctx(|_| {
            (
                ErrorKind::Network,
                format!("{} is not available to bind to", listen),
            )
        })?;
    ctx.db
        .mutate(|db| {
            db.as_webserver_mut().as_listen_mut().ser(&Some(listen))?;
            Ok(())
        })
        .await
        .result
}
/// Return the webserver's configured listen address, if one is set.
pub async fn get_listen(ctx: TunnelContext) -> Result<Option<SocketAddr>, Error> {
    let peek = ctx.db.peek().await;
    peek.as_webserver().as_listen().de()
}
/// List every IP address currently assigned to a tracked network interface
/// (one entry per subnet), as candidates to bind the webserver to.
pub async fn get_available_ips(ctx: TunnelContext) -> Result<Vec<IpAddr>, Error> {
    Ok(ctx.net_iface.peek(|interfaces| {
        let mut addrs = Vec::new();
        for info in interfaces.values() {
            // Interfaces with no IP info contribute nothing.
            if let Some(ip_info) = info.ip_info.as_ref() {
                addrs.extend(ip_info.subnets.iter().map(|subnet| subnet.addr()));
            }
        }
        addrs
    }))
}
/// Mark the webserver as enabled.
///
/// Preconditions are checked inside the same db transaction; each failure
/// maps to a distinct ErrorKind that `init_web` dispatches on to prompt the
/// user for the missing piece:
/// - listen address unset -> ErrorKind::ParseNetAddress
/// - certificate unset    -> ErrorKind::OpenSsl
/// - password unset       -> ErrorKind::Authorization
pub async fn enable_web(ctx: TunnelContext) -> Result<(), Error> {
    ctx.db
        .mutate(|db| {
            if db.as_webserver().as_listen().transpose_ref().is_none() {
                return Err(Error::new(
                    eyre!("Listen is not set"),
                    ErrorKind::ParseNetAddress,
                ));
            }
            if db.as_webserver().as_certificate().transpose_ref().is_none() {
                return Err(Error::new(
                    eyre!("Certificate is not set"),
                    ErrorKind::OpenSsl,
                ));
            }
            if db.as_password().transpose_ref().is_none() {
                return Err(Error::new(
                    eyre!("Password is not set"),
                    ErrorKind::Authorization,
                ));
            };
            db.as_webserver_mut().as_enabled_mut().ser(&true)
        })
        .await
        .result
}
/// Mark the webserver as disabled; all other settings are left intact.
pub async fn disable_web(ctx: TunnelContext) -> Result<(), Error> {
    ctx.db
        .mutate(|db| {
            db.as_webserver_mut().as_enabled_mut().ser(&false)?;
            Ok(())
        })
        .await
        .result
}
/// Reset the webserver: disable it and clear the listen address, the
/// certificate, and the password.
pub async fn reset_web(ctx: TunnelContext) -> Result<(), Error> {
    ctx.db
        .mutate(|db| {
            db.as_webserver_mut().as_enabled_mut().ser(&false)?;
            db.as_webserver_mut().as_listen_mut().ser(&None)?;
            db.as_webserver_mut().as_certificate_mut().ser(&None)?;
            db.as_password_mut().ser(&None)
        })
        .await
        .result
}
/// Validate a DNS hostname: total length at most 253 bytes, no leading or
/// trailing dot, and every label 1-63 bytes of ASCII letters, digits, or
/// hyphens with no leading/trailing hyphen.
fn is_valid_domain(domain: &str) -> bool {
    let overall_ok = !domain.is_empty()
        && domain.len() <= 253
        && !domain.starts_with('.')
        && !domain.ends_with('.');
    overall_ok
        && domain.split('.').all(|label| {
            let bytes = label.as_bytes();
            match (bytes.first(), bytes.last()) {
                // Empty labels (e.g. "a..b") fall through to the reject arm.
                (Some(&first), Some(&last)) => {
                    bytes.len() <= 63
                        && first != b'-'
                        && last != b'-'
                        && bytes.iter().all(|b| b.is_ascii_alphanumeric() || *b == b'-')
                }
                _ => false,
            }
        })
}
/// Interactive CLI wizard that brings the tunnel's web interface up.
///
/// Repeatedly attempts `web.enable`; each failure's ErrorKind identifies the
/// missing precondition (see `enable_web`), and the matching arm prompts the
/// user to supply it before looping:
/// - ParseNetAddress -> choose a listen IP and port
/// - OpenSsl         -> generate or import a TLS certificate
/// - Authorization   -> generate a random password
/// On success, prints the URL, the password (only if generated this run),
/// and the SSL certificate, then returns.
pub async fn init_web(ctx: CliContext) -> Result<(), Error> {
    // Set only when this run generated the password; used to tailor output.
    let mut password = None;
    loop {
        match ctx
            .call_remote::<TunnelContext>("web.enable", json!({}))
            .await
        {
            Ok(_) => {
                let listen = from_value::<SocketAddr>(
                    ctx.call_remote::<TunnelContext>("web.get-listen", json!({}))
                        .await?,
                )?;
                println!("✅ Success! ✅");
                println!(
                    "The webserver is running. Below is your URL{} and SSL certificate.",
                    if password.is_some() {
                        ", password,"
                    } else {
                        ""
                    }
                );
                println!();
                println!("🌐 URL");
                println!("https://{listen}");
                if listen.ip().is_unspecified() {
                    println!(concat!(
                        "Note: this is the unspecified address. ",
                        "This means you can use any IP address available to this device to connect. ",
                        "Using the above address as-is will only work from this device."
                    ));
                } else if listen.ip().is_loopback() {
                    println!(concat!(
                        "Note: this is a loopback address. ",
                        "This is only recommended if you are planning to run a proxy in front of the web ui. ",
                        "Using the above address as-is will only work from this device."
                    ));
                }
                println!();
                if let Some(password) = password {
                    println!("🔒 Password");
                    println!("{password}");
                    println!();
                    println!(concat!(
                        "If you lose or forget your password, you can reset it using the command: ",
                        "start-tunnel auth reset-password"
                    ));
                } else {
                    println!(concat!(
                        "Your password was set up previously. ",
                        "If you don't remember it, you can reset it using the command: ",
                        "start-tunnel auth reset-password"
                    ));
                }
                println!();
                let cert = from_value::<Pem<Vec<X509>>>(
                    ctx.call_remote::<TunnelContext>("web.get-certificate", json!({}))
                        .await?,
                )?;
                println!("📝 SSL Certificate:");
                print!("{cert}");
                println!(concat!(
                    "If you haven't already, ",
                    "trust the certificate in your system keychain and/or browser."
                ));
                return Ok(());
            }
            // Listen address not configured yet.
            Err(e) if e.kind == ErrorKind::ParseNetAddress => {
                println!("Select the IP address at which to host the web interface:");
                let mut suggested_addrs = from_value::<Vec<IpAddr>>(
                    ctx.call_remote::<TunnelContext>("web.get-available-ips", json!({}))
                        .await?,
                )?;
                // Rank candidates (lower key = preferred): global v4 (0),
                // other v6 (1), private v4 (2), loopback v4 (3),
                // link-local v6 (4), loopback v6 (5).
                suggested_addrs.sort_by_cached_key(|a| match a {
                    IpAddr::V4(a) => {
                        if a.is_loopback() {
                            3
                        } else if a.is_private() {
                            2
                        } else {
                            0
                        }
                    }
                    IpAddr::V6(a) => {
                        if a.is_loopback() {
                            5
                        } else if a.is_unicast_link_local() {
                            4
                        } else {
                            1
                        }
                    }
                });
                let ip = if suggested_addrs.is_empty() {
                    prompt("Listen Address: ", parse_as::<IpAddr>("IP Address"), None).await?
                } else if suggested_addrs.len() > 16 {
                    // Too many to list: free-form prompt defaulting to the
                    // best-ranked candidate.
                    prompt(
                        &format!("Listen Address [{}]: ", suggested_addrs[0]),
                        parse_as::<IpAddr>("IP Address"),
                        Some(suggested_addrs[0]),
                    )
                    .await?
                } else {
                    *choose_custom_display("Listen Address:", &suggested_addrs, |a| match a {
                        a if a.is_loopback() => {
                            format!("{a} (Loopback Address: only use if planning to proxy traffic)")
                        }
                        IpAddr::V4(a) if a.is_private() => {
                            format!("{a} (Private Address: only available from Local Area Network)")
                        }
                        IpAddr::V6(a) if a.is_unicast_link_local() => {
                            format!(
                                "[{a}] (Private Address: only available from Local Area Network)"
                            )
                        }
                        IpAddr::V6(a) => format!("[{a}]"),
                        a => a.to_string(),
                    })
                    .await?
                };
                println!(concat!(
                    "Enter the port at which to host the web interface. ",
                    "The recommended default is 8443. ",
                    "If you change the default, choose an uncommon port to avoid conflicts: "
                ));
                let port = prompt("Port [8443]: ", parse_as::<u16>("port"), Some(8443)).await?;
                let listen = SocketAddr::new(ip, port);
                ctx.call_remote::<TunnelContext>(
                    "web.set-listen",
                    to_value(&SetListenParams { listen })?,
                )
                .await?;
                println!();
            }
            // No TLS certificate configured yet.
            Err(e) if e.kind == ErrorKind::OpenSsl => {
                enum Choice {
                    Generate,
                    Provide,
                }
                impl std::fmt::Display for Choice {
                    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                        match self {
                            Self::Generate => write!(f, "Generate a Self Signed Certificate"),
                            Self::Provide => write!(f, "Provide your own certificate and key"),
                        }
                    }
                }
                let options = vec![Choice::Generate, Choice::Provide];
                let choice = choose(
                    concat!(
                        "Select whether to autogenerate a self-signed SSL certificate ",
                        "or provide your own certificate and key:"
                    ),
                    &options,
                )
                .await?;
                match choice {
                    Choice::Generate => {
                        // Default the SAN to the configured listen IP, unless
                        // it is the unspecified address (not a usable SAN).
                        let listen = from_value::<Option<SocketAddr>>(
                            ctx.call_remote::<TunnelContext>("web.get-listen", json!({}))
                                .await?,
                        )?
                        .filter(|a| !a.ip().is_unspecified());
                        let default_prompt = if let Some(listen) = listen {
                            format!("Subject Alternative Name(s) [{}]: ", listen.ip())
                        } else {
                            "Subject Alternative Name(s): ".to_string()
                        };
                        println!(
                            "List all IP addresses and domains for which to sign the certificate, separated by commas."
                        );
                        let san_info = prompt(
                            &default_prompt,
                            |s| {
                                s.split(",")
                                    .map(|s| {
                                        let s = s.trim();
                                        if let Ok(ip) = s.parse::<IpAddr>() {
                                            Ok(InternedString::from_display(&ip))
                                        } else if is_valid_domain(s) {
                                            Ok(s.into())
                                        } else {
                                            Err(format!("{s} is not a valid ip address or domain"))
                                        }
                                    })
                                    .collect()
                            },
                            listen.map(|l| vec![InternedString::from_display(&l.ip())]),
                        )
                        .await?;
                        ctx.call_remote::<TunnelContext>(
                            "web.generate-certificate",
                            to_value(&GenerateCertParams { subject: san_info })?,
                        )
                        .await?;
                    }
                    Choice::Provide => {
                        // Reuse the interactive import flow with a synthesized
                        // handler path.
                        import_certificate_cli(HandlerArgs {
                            context: ctx.clone(),
                            parent_method: vec!["web", "import-certificate"].into(),
                            method: VecDeque::new(),
                            params: Empty {},
                            inherited_params: Empty {},
                            raw_params: json!({}),
                        })
                        .await?;
                    }
                }
                println!();
            }
            // No password configured yet.
            Err(e) if e.kind == ErrorKind::Authorization => {
                println!("Generating a random password...");
                let params = SetPasswordParams {
                    password: base32::encode(
                        base32::Alphabet::Rfc4648Lower { padding: false },
                        &rand::random::<[u8; 16]>(),
                    ),
                };
                ctx.call_remote::<TunnelContext>("auth.set-password", to_value(&params)?)
                    .await?;
                password = Some(params.password);
                println!();
            }
            Err(e) => return Err(e.into()),
        }
    }
}

View File

@@ -1,19 +1,22 @@
use std::collections::BTreeMap;
use std::net::{Ipv4Addr, SocketAddrV4};
use std::net::{Ipv4Addr, SocketAddr};
use ed25519_dalek::{SigningKey, VerifyingKey};
use imbl_value::InternedString;
use ipnet::Ipv4Net;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use ts_rs::TS;
use x25519_dalek::{PublicKey, StaticSecret};
use crate::prelude::*;
use crate::util::Invoke;
use crate::util::io::write_file_atomic;
use crate::util::serde::Base64;
#[derive(Deserialize, Serialize, HasModel)]
pub const WIREGUARD_INTERFACE_NAME: &str = "wg-start-tunnel";
#[derive(Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
pub struct WgServer {
@@ -37,7 +40,7 @@ impl WgServer {
pub async fn sync(&self) -> Result<(), Error> {
Command::new("wg-quick")
.arg("down")
.arg("wg0")
.arg(WIREGUARD_INTERFACE_NAME)
.invoke(ErrorKind::Network)
.await
.or_else(|e| {
@@ -49,21 +52,23 @@ impl WgServer {
}
})?;
write_file_atomic(
"/etc/wireguard/wg0.conf",
const_format::formatcp!("/etc/wireguard/{WIREGUARD_INTERFACE_NAME}.conf"),
self.server_config().to_string().as_bytes(),
)
.await?;
Command::new("wg-quick")
.arg("up")
.arg("wg0")
.arg(WIREGUARD_INTERFACE_NAME)
.invoke(ErrorKind::Network)
.await?;
Ok(())
}
}
#[derive(Default, Deserialize, Serialize)]
pub struct WgSubnetMap(pub BTreeMap<Ipv4Net, WgSubnetConfig>);
#[derive(Default, Deserialize, Serialize, TS)]
pub struct WgSubnetMap(
#[ts(as = "BTreeMap::<String, WgSubnetConfig>")] pub BTreeMap<Ipv4Net, WgSubnetConfig>,
);
impl Map for WgSubnetMap {
type Key = Ipv4Net;
type Value = WgSubnetConfig;
@@ -75,35 +80,41 @@ impl Map for WgSubnetMap {
}
}
#[derive(Default, Deserialize, Serialize, HasModel)]
#[derive(Default, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
pub struct WgSubnetConfig {
pub default_forward_target: Option<Ipv4Addr>,
pub clients: BTreeMap<Ipv4Addr, WgConfig>,
pub name: InternedString,
pub clients: WgSubnetClients,
}
impl WgSubnetConfig {
pub fn new() -> Self {
Self::default()
}
pub fn add_client<'a>(
&'a mut self,
subnet: Ipv4Net,
) -> Result<(Ipv4Addr, &'a WgConfig), Error> {
let addr = subnet
.hosts()
.find(|a| !self.clients.contains_key(a))
.ok_or_else(|| Error::new(eyre!("subnet exhausted"), ErrorKind::Network))?;
let config = self.clients.entry(addr).or_insert(WgConfig::generate());
Ok((addr, config))
pub fn new(name: InternedString) -> Self {
Self {
name,
..Self::default()
}
}
}
pub struct WgKey(SigningKey);
#[derive(Default, Deserialize, Serialize, TS)]
pub struct WgSubnetClients(pub BTreeMap<Ipv4Addr, WgConfig>);
impl Map for WgSubnetClients {
type Key = Ipv4Addr;
type Value = WgConfig;
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
Self::key_string(key)
}
fn key_string(key: &Self::Key) -> Result<InternedString, Error> {
Ok(InternedString::from_display(key))
}
}
#[derive(Clone)]
pub struct WgKey(StaticSecret);
impl WgKey {
pub fn generate() -> Self {
Self(SigningKey::generate(
&mut ssh_key::rand_core::OsRng::default(),
Self(StaticSecret::random_from_rng(
ssh_key::rand_core::OsRng::default(),
))
}
}
@@ -113,33 +124,39 @@ impl AsRef<[u8]> for WgKey {
}
}
impl TryFrom<Vec<u8>> for WgKey {
type Error = ed25519_dalek::SignatureError;
type Error = Error;
fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
Ok(Self(value.as_slice().try_into()?))
Ok(Self(
<[u8; 32]>::try_from(value)
.map_err(|_| Error::new(eyre!("invalid key length"), ErrorKind::Deserialization))?
.into(),
))
}
}
impl std::ops::Deref for WgKey {
type Target = SigningKey;
type Target = StaticSecret;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Base64<WgKey> {
pub fn verifying_key(&self) -> Base64<VerifyingKey> {
Base64(self.0.verifying_key())
pub fn verifying_key(&self) -> Base64<PublicKey> {
Base64((&*self.0).into())
}
}
#[derive(Deserialize, Serialize, HasModel)]
#[derive(Clone, Deserialize, Serialize, HasModel, TS)]
#[serde(rename_all = "camelCase")]
#[model = "Model<Self>"]
pub struct WgConfig {
pub name: InternedString,
pub key: Base64<WgKey>,
pub psk: Base64<[u8; 32]>,
}
impl WgConfig {
pub fn generate() -> Self {
pub fn generate(name: InternedString) -> Self {
Self {
name,
key: Base64(WgKey::generate()),
psk: Base64(rand::random()),
}
@@ -150,12 +167,12 @@ impl WgConfig {
client_addr: addr,
}
}
pub fn client_config<'a>(
&'a self,
pub fn client_config(
self,
addr: Ipv4Addr,
server_pubkey: Base64<VerifyingKey>,
server_addr: SocketAddrV4,
) -> ClientConfig<'a> {
server_pubkey: Base64<PublicKey>,
server_addr: SocketAddr,
) -> ClientConfig {
ClientConfig {
client_config: self,
client_addr: addr,
@@ -181,19 +198,33 @@ impl<'a> std::fmt::Display for ServerPeerConfig<'a> {
}
}
pub struct ClientConfig<'a> {
client_config: &'a WgConfig,
client_addr: Ipv4Addr,
server_pubkey: Base64<VerifyingKey>,
server_addr: SocketAddrV4,
fn deserialize_verifying_key<'de, D>(deserializer: D) -> Result<Base64<PublicKey>, D::Error>
where
D: serde::Deserializer<'de>,
{
Base64::<Vec<u8>>::deserialize(deserializer).and_then(|b| {
Ok(Base64(PublicKey::from(<[u8; 32]>::try_from(b.0).map_err(
|e: Vec<u8>| serde::de::Error::invalid_length(e.len(), &"a 32 byte base64 string"),
)?)))
})
}
impl<'a> std::fmt::Display for ClientConfig<'a> {
#[derive(Clone, Serialize, Deserialize)]
pub struct ClientConfig {
client_config: WgConfig,
client_addr: Ipv4Addr,
#[serde(deserialize_with = "deserialize_verifying_key")]
server_pubkey: Base64<PublicKey>,
server_addr: SocketAddr,
}
impl std::fmt::Display for ClientConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
include_str!("./client.conf.template"),
name = self.client_config.name,
privkey = self.client_config.key.to_padded_string(),
psk = self.client_config.psk,
psk = self.client_config.psk.to_padded_string(),
addr = self.client_addr,
server_pubkey = self.server_pubkey.to_padded_string(),
server_addr = self.server_addr,
@@ -212,7 +243,7 @@ impl<'a> std::fmt::Display for ServerConfig<'a> {
server_port = server.port,
server_privkey = server.key.to_padded_string(),
)?;
for (addr, peer) in server.subnets.0.values().flat_map(|s| &s.clients) {
for (addr, peer) in server.subnets.0.values().flat_map(|s| &s.clients.0) {
write!(f, "{}", peer.server_peer_config(*addr))?;
}
Ok(())

View File

@@ -8,6 +8,30 @@ pub use eq_map::EqMap;
pub use eq_set::EqSet;
use imbl::OrdMap;
/// Removes every entry of `map` for which `f(&key, &mut value)` returns
/// `false`, visiting keys in ascending order (an in-place `retain` for
/// `imbl::OrdMap`, which has no such method of its own).
///
/// `f` may also mutate retained values through the `&mut V` it receives.
pub fn ordmap_retain<K: Ord + Clone, V: Clone, F: FnMut(&K, &mut V) -> bool>(
map: &mut OrdMap<K, V>,
mut f: F,
) {
// Cursor over the map: the last key visited, or `None` to start at the minimum.
let mut prev = None;
loop {
// Find the first key strictly greater than the previous one; because we
// re-query the map each step, removing the current entry cannot
// invalidate the iteration.
let next = if let Some(k) = prev.take() {
map.range((Bound::Excluded(k), Bound::Unbounded)).next()
} else {
map.get_min().map(|(k, v)| (k, v))
};
let Some((k, _)) = next else {
break;
};
let k = k.clone(); // owned copy releases the shared borrow of `map` before `get_mut`
let v = map.get_mut(&k).unwrap();
if !f(&k, v) {
map.remove(&k);
}
prev = Some(k);
}
}
pub struct OrdMapIterMut<'a, K: 'a, V: 'a> {
map: *mut OrdMap<K, V>,
prev: Option<&'a K>,

View File

@@ -1,10 +1,13 @@
use std::pin::Pin;
use std::task::{Context, Poll};
use futures::future::{abortable, pending, BoxFuture, FusedFuture};
use axum::middleware::FromFn;
use futures::future::{BoxFuture, FusedFuture, abortable, pending};
use futures::stream::{AbortHandle, Abortable, BoxStream};
use futures::{Future, FutureExt, Stream, StreamExt};
use rpc_toolkit::from_fn_blocking;
use tokio::sync::watch;
use tokio::task::LocalSet;
use crate::prelude::*;
@@ -158,6 +161,31 @@ impl<'a> Until<'a> {
}
}
pub async fn make_send<F, Fut, T>(f: F) -> Result<T, Error>
where
F: FnOnce() -> Fut + Send + 'static,
Fut: Future<Output = Result<T, Error>> + 'static,
T: Send + 'static,
{
tokio::task::spawn_blocking(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let local = LocalSet::new();
local.block_on(&rt, async move { f().await })
})
.await
.map_err(|e| {
Error::new(
eyre!("Task running non-Send future panicked: {}", e),
ErrorKind::Unknown,
)
})?
}
#[tokio::test]
async fn test_cancellable() {
use std::sync::Arc;

View File

@@ -6,15 +6,15 @@ use std::os::unix::prelude::MetadataExt;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::str::FromStr;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::task::{Poll, Waker};
use std::time::Duration;
use bytes::{Buf, BytesMut};
use clap::builder::ValueParserFactory;
use futures::future::{BoxFuture, Fuse};
use futures::{AsyncSeek, FutureExt, Stream, StreamExt, TryStreamExt};
use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
use helpers::{AtomicFile, NonDetachingJoinHandle};
use inotify::{EventMask, EventStream, Inotify, WatchMask};
use models::FromStrParser;
@@ -22,7 +22,8 @@ use nix::unistd::{Gid, Uid};
use serde::{Deserialize, Serialize};
use tokio::fs::{File, OpenOptions};
use tokio::io::{
duplex, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf, WriteHalf,
AsyncRead, AsyncReadExt, AsyncSeek, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf, SeekFrom,
WriteHalf, duplex,
};
use tokio::net::TcpStream;
use tokio::sync::{Notify, OwnedMutexGuard};

View File

@@ -59,7 +59,7 @@ impl StartOSLogger {
fn base_subscriber(logfile: LogFile) -> impl Subscriber {
use tracing_error::ErrorLayer;
use tracing_subscriber::prelude::*;
use tracing_subscriber::{fmt, EnvFilter};
use tracing_subscriber::{EnvFilter, fmt};
let filter_layer = || {
EnvFilter::builder()

View File

@@ -14,10 +14,10 @@ use ::serde::{Deserialize, Serialize};
use async_trait::async_trait;
use color_eyre::eyre::{self, eyre};
use fd_lock_rs::FdLock;
use futures::future::BoxFuture;
use futures::FutureExt;
use helpers::canonicalize;
use futures::future::BoxFuture;
pub use helpers::NonDetachingJoinHandle;
use helpers::canonicalize;
use imbl_value::InternedString;
use lazy_static::lazy_static;
pub use models::VersionString;
@@ -25,7 +25,7 @@ use pin_project::pin_project;
use sha2::Digest;
use tokio::fs::File;
use tokio::io::{AsyncRead, AsyncReadExt, BufReader};
use tokio::sync::{oneshot, Mutex, OwnedMutexGuard, RwLock};
use tokio::sync::{Mutex, OwnedMutexGuard, RwLock, oneshot};
use tracing::instrument;
use ts_rs::TS;
use url::Url;
@@ -49,7 +49,9 @@ pub mod net;
pub mod rpc;
pub mod rpc_client;
pub mod serde;
//pub mod squashfs;
pub mod sync;
pub mod tui;
#[derive(Clone, Copy, Debug, ::serde::Deserialize, ::serde::Serialize)]
pub enum Never {}

View File

@@ -1055,7 +1055,11 @@ impl<T: TryFrom<Vec<u8>>> ValueParserFactory for Base64<T> {
Self::Parser::new()
}
}
impl<'de, T: TryFrom<Vec<u8>>> Deserialize<'de> for Base64<T> {
impl<'de, T> Deserialize<'de> for Base64<T>
where
Base64<T>: FromStr,
<Base64<T> as FromStr>::Err: std::fmt::Display,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
@@ -1195,6 +1199,21 @@ impl PemEncoding for X509 {
}
}
/// PEM (de)serialization for a certificate chain: a `Vec<X509>` is
/// represented as concatenated PEM certificate blocks.
impl PemEncoding for Vec<X509> {
// Parses every PEM certificate block in the input into a stack.
fn from_pem<E: serde::de::Error>(pem: &str) -> Result<Self, E> {
X509::stack_from_pem(pem.as_bytes()).map_err(E::custom)
}
// Concatenates each certificate's PEM form, appending a newline after each.
fn to_pem<E: serde::ser::Error>(&self) -> Result<String, E> {
self.iter()
.map(|x| x.to_pem())
.try_fold(String::new(), |mut acc, x| {
acc.push_str(&x?);
acc.push_str("\n");
Ok(acc)
})
}
}
impl PemEncoding for PKey<Private> {
fn from_pem<E: serde::de::Error>(pem: &str) -> Result<Self, E> {
Self::private_key_from_pem(pem.as_bytes()).map_err(E::custom)

File diff suppressed because it is too large Load Diff

View File

@@ -5,8 +5,8 @@ use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Weak};
use std::task::{Poll, Waker};
use futures::stream::BoxStream;
use futures::Stream;
use futures::stream::BoxStream;
use crate::prelude::*;
@@ -266,6 +266,9 @@ impl<T> Watch<T> {
version: 0,
}
}
/// Number of live handles sharing this watch's state.
///
/// NOTE(review): this returns the `Arc` strong count, which counts this
/// `Watch` itself in addition to any other clones — confirm callers expect
/// that rather than "number of *other* watchers".
pub fn watcher_count(&self) -> usize {
Arc::strong_count(&self.shared)
}
#[cfg_attr(feature = "unstable", inline(never))]
pub fn poll_changed(&mut self, cx: &mut std::task::Context<'_>) -> Poll<()> {
self.shared.mutate(|shared| {

View File

@@ -0,0 +1,144 @@
use std::io::Write;
use std::str::FromStr;
use r3bl_tui::{DefaultIoDevices, ReadlineAsyncContext, ReadlineEvent};
use crate::prelude::*;
/// Converts a `miette::Error` from the readline/TUI library into our `Error`
/// type, classifying it as `ErrorKind::Filesystem`.
fn map_miette(m: miette::Error) -> Error {
Error::new(eyre!("{m}"), ErrorKind::Filesystem)
}
/// Error returned when the terminal is not interactive, so the
/// readline-based wizards below cannot run.
fn noninteractive_err() -> Error {
Error::new(
eyre!("Terminal must be in interactive mode for this wizard"),
ErrorKind::Filesystem,
)
}
/// Builds a validator closure that parses input as `T`, reporting a
/// user-friendly message naming `what` (e.g. "port") on failure.
pub fn parse_as<'a, T>(what: &'a str) -> impl Fn(&str) -> Result<T, String> + 'a
where
    T: FromStr,
{
    move |input| match input.parse::<T>() {
        Ok(value) => Ok(value),
        Err(_) => Err(format!("Please enter a valid {what}.")),
    }
}
/// Prompts the user for a single line of input, re-asking until `parse`
/// accepts it.
///
/// - Non-empty input is passed to `parse`; on failure the error is printed
///   and the prompt repeats.
/// - Empty input returns `default` when one is provided, otherwise keeps
///   prompting.
/// - EOF or Ctrl-C aborts with `ErrorKind::Cancelled`.
/// - Fails with a "must be interactive" error when not attached to a TTY.
pub async fn prompt<T, E: std::fmt::Display, Parse: FnMut(&str) -> Result<T, E>>(
prompt: &str,
mut parse: Parse,
default: Option<T>,
) -> Result<T, Error> {
let mut rl_ctx = ReadlineAsyncContext::try_new(Some(prompt))
.await
.map_err(map_miette)?
.ok_or_else(noninteractive_err)?;
let res = loop {
match rl_ctx.read_line().await.map_err(map_miette)? {
ReadlineEvent::Line(l) => {
let l = l.trim();
if !l.is_empty() {
match parse(l) {
Ok(a) => break a,
Err(e) => {
// Show the validation error, then fall through to re-prompt.
writeln!(&mut rl_ctx.shared_writer, "{e}")?;
}
}
} else if let Some(default) = default {
break default;
}
}
ReadlineEvent::Eof | ReadlineEvent::Interrupted => {
return Err(Error::new(eyre!("Aborted"), ErrorKind::Cancelled));
}
_ => (),
}
};
// Tear down the readline context cleanly before returning the parsed value.
rl_ctx.request_shutdown(None).await.map_err(map_miette)?;
rl_ctx.await_shutdown().await;
Ok(res)
}
/// Prints `prompt`, then feeds each subsequent input line to `handle_line`
/// until it produces `Ok(Some(result))`.
///
/// `Ok(None)` means "keep reading more lines"; an `Err` is printed and input
/// continues. EOF or Ctrl-C aborts with `ErrorKind::Cancelled`. Fails when
/// the terminal is non-interactive.
pub async fn prompt_multiline<
T,
E: std::fmt::Display,
HandleLine: FnMut(String) -> Result<Option<T>, E>,
>(
prompt: &str,
mut handle_line: HandleLine,
) -> Result<T, Error> {
println!("{prompt}");
let mut rl_ctx = ReadlineAsyncContext::try_new(None::<&str>)
.await
.map_err(map_miette)?
.ok_or_else(noninteractive_err)?;
let res = loop {
match rl_ctx.read_line().await.map_err(map_miette)? {
ReadlineEvent::Line(l) => match handle_line(l) {
Ok(Some(a)) => break a,
Ok(None) => (),
Err(e) => writeln!(&mut rl_ctx.shared_writer, "{e}")?,
},
ReadlineEvent::Eof | ReadlineEvent::Interrupted => {
return Err(Error::new(eyre!("Aborted"), ErrorKind::Cancelled));
}
_ => (),
}
};
// Shut the readline context down cleanly before returning.
rl_ctx.request_shutdown(None).await.map_err(map_miette)?;
rl_ctx.await_shutdown().await;
Ok(res)
}
/// Presents an interactive single-selection menu over `choices`, rendering
/// each entry with `display`.
///
/// Returns a reference to the selected element of `choices`. Aborting
/// without a selection yields `ErrorKind::Cancelled`; failure to map the
/// selected string back to an input element (should be impossible) yields
/// `ErrorKind::Incoherent`.
pub async fn choose_custom_display<'t, T: std::fmt::Display>(
    prompt: &str,
    choices: &'t [T],
    mut display: impl FnMut(&T) -> String,
) -> Result<&'t T, Error> {
    let mut io = DefaultIoDevices::default();
    let style = r3bl_tui::readline_async::StyleSheet::default();
    // Render every choice up front so the selection can be mapped back to an
    // index afterwards. (`iter()` instead of `into_iter()` on a slice ref.)
    let string_choices = choices.iter().map(|c| display(c)).collect::<Vec<_>>();
    let choice = r3bl_tui::readline_async::choose(
        prompt,
        string_choices.clone(),
        None,
        None,
        r3bl_tui::HowToChoose::Single,
        style,
        (
            &mut io.output_device,
            &mut io.input_device,
            io.maybe_shared_writer,
        ),
    )
    .await
    .map_err(map_miette)?;
    if choice.is_empty() {
        return Err(Error::new(eyre!("Aborted"), ErrorKind::Cancelled));
    }
    // Map the selected rendered string back to the original element.
    let idx = string_choices
        .iter()
        .position(|s| s.as_str() == choice[0].as_str())
        .ok_or_else(|| {
            Error::new(
                eyre!("selected choice does not appear in input"),
                ErrorKind::Incoherent,
            )
        })?;
    let choice = &choices[idx];
    println!("{prompt} {choice}");
    Ok(choice)
}
/// Interactive single-selection menu over `choices`, rendering each entry
/// with its `Display` implementation.
pub async fn choose<'t, T: std::fmt::Display>(
    prompt: &str,
    choices: &'t [T],
) -> Result<&'t T, Error> {
    // Delegate to the custom-display variant, stringifying via `Display`.
    let render = |choice: &T| choice.to_string();
    choose_custom_display(prompt, choices, render).await
}

View File

@@ -5,14 +5,14 @@ use std::panic::{RefUnwindSafe, UnwindSafe};
use color_eyre::eyre::eyre;
use futures::future::BoxFuture;
use futures::{Future, FutureExt};
use imbl_value::{to_value, InternedString};
use imbl_value::{InternedString, to_value};
use patch_db::json_ptr::ROOT;
use crate::Error;
use crate::context::RpcContext;
use crate::db::model::Database;
use crate::prelude::*;
use crate::progress::PhaseProgressTrackerHandle;
use crate::Error;
mod v0_3_5;
mod v0_3_5_1;
@@ -51,8 +51,9 @@ mod v0_4_0_alpha_9;
mod v0_4_0_alpha_10;
mod v0_4_0_alpha_11;
mod v0_4_0_alpha_12;
pub type Current = v0_4_0_alpha_11::Version; // VERSION_BUMP
pub type Current = v0_4_0_alpha_12::Version; // VERSION_BUMP
impl Current {
#[instrument(skip(self, db))]
@@ -97,8 +98,8 @@ pub async fn post_init(
.as_server_info()
.as_post_init_migration_todos()
.de()?;
progress.start();
if !todos.is_empty() {
progress.set_total(todos.len() as u64);
while let Some((version, input)) = {
peek = ctx.db.peek().await;
peek.as_public()
@@ -121,7 +122,6 @@ pub async fn post_init(
})
.await
.result?;
progress += 1;
}
}
progress.complete();
@@ -166,7 +166,8 @@ enum Version {
V0_4_0_alpha_8(Wrapper<v0_4_0_alpha_8::Version>),
V0_4_0_alpha_9(Wrapper<v0_4_0_alpha_9::Version>),
V0_4_0_alpha_10(Wrapper<v0_4_0_alpha_10::Version>),
V0_4_0_alpha_11(Wrapper<v0_4_0_alpha_11::Version>), // VERSION_BUMP
V0_4_0_alpha_11(Wrapper<v0_4_0_alpha_11::Version>),
V0_4_0_alpha_12(Wrapper<v0_4_0_alpha_12::Version>), // VERSION_BUMP
Other(exver::Version),
}
@@ -220,7 +221,8 @@ impl Version {
Self::V0_4_0_alpha_8(v) => DynVersion(Box::new(v.0)),
Self::V0_4_0_alpha_9(v) => DynVersion(Box::new(v.0)),
Self::V0_4_0_alpha_10(v) => DynVersion(Box::new(v.0)),
Self::V0_4_0_alpha_11(v) => DynVersion(Box::new(v.0)), // VERSION_BUMP
Self::V0_4_0_alpha_11(v) => DynVersion(Box::new(v.0)),
Self::V0_4_0_alpha_12(v) => DynVersion(Box::new(v.0)), // VERSION_BUMP
Self::Other(v) => {
return Err(Error::new(
eyre!("unknown version {v}"),
@@ -266,7 +268,8 @@ impl Version {
Version::V0_4_0_alpha_8(Wrapper(x)) => x.semver(),
Version::V0_4_0_alpha_9(Wrapper(x)) => x.semver(),
Version::V0_4_0_alpha_10(Wrapper(x)) => x.semver(),
Version::V0_4_0_alpha_11(Wrapper(x)) => x.semver(), // VERSION_BUMP
Version::V0_4_0_alpha_11(Wrapper(x)) => x.semver(),
Version::V0_4_0_alpha_12(Wrapper(x)) => x.semver(), // VERSION_BUMP
Version::Other(x) => x.clone(),
}
}

View File

@@ -1,4 +1,4 @@
use std::collections::BTreeMap;
use std::collections::{BTreeMap, BTreeSet};
use std::ffi::OsStr;
use std::path::Path;
@@ -7,7 +7,7 @@ use const_format::formatcp;
use ed25519_dalek::SigningKey;
use exver::{PreReleaseSegment, VersionRange};
use imbl_value::{InternedString, json};
use models::{PackageId, ReplayId};
use models::{HostId, Id, PackageId, ReplayId};
use openssl::pkey::PKey;
use openssl::x509::X509;
use sqlx::postgres::PgConnectOptions;
@@ -24,8 +24,9 @@ use crate::disk::mount::filesystem::cifs::Cifs;
use crate::disk::mount::util::unmount;
use crate::hostname::Hostname;
use crate::net::forward::AvailablePorts;
use crate::net::host::Host;
use crate::net::keys::KeyStore;
use crate::net::tor::TorSecretKey;
use crate::net::tor::{OnionAddress, TorSecretKey};
use crate::notifications::Notifications;
use crate::prelude::*;
use crate::s9pk::merkle_archive::source::multi_cursor_file::MultiCursorFile;
@@ -93,69 +94,6 @@ async fn init_postgres(datadir: impl AsRef<Path>) -> Result<PgPool, Error> {
crate::disk::mount::util::bind(&db_dir, "/var/lib/postgresql", false).await?;
let pg_version_string = pg_version.to_string();
let pg_version_path = db_dir.join(&pg_version_string);
if exists
// maybe migrate
{
let incomplete_path = db_dir.join(format!("{pg_version}.migration.incomplete"));
if tokio::fs::metadata(&incomplete_path).await.is_ok() // previous migration was incomplete
&& tokio::fs::metadata(&pg_version_path).await.is_ok()
{
tokio::fs::remove_dir_all(&pg_version_path).await?;
}
if tokio::fs::metadata(&pg_version_path).await.is_err()
// need to migrate
{
let conf_dir = Path::new("/etc/postgresql").join(pg_version.to_string());
let conf_dir_tmp = {
let mut tmp = conf_dir.clone();
tmp.set_extension("tmp");
tmp
};
if tokio::fs::metadata(&conf_dir).await.is_ok() {
Command::new("mv")
.arg(&conf_dir)
.arg(&conf_dir_tmp)
.invoke(ErrorKind::Filesystem)
.await?;
}
let mut old_version = pg_version;
while old_version > 13
/* oldest pg version included in startos */
{
old_version -= 1;
let old_datadir = db_dir.join(old_version.to_string());
if tokio::fs::metadata(&old_datadir).await.is_ok() {
tokio::fs::File::create(&incomplete_path)
.await?
.sync_all()
.await?;
Command::new("pg_upgradecluster")
.arg(old_version.to_string())
.arg("main")
.invoke(crate::ErrorKind::Database)
.await?;
break;
}
}
if tokio::fs::metadata(&conf_dir).await.is_ok() {
if tokio::fs::metadata(&conf_dir).await.is_ok() {
tokio::fs::remove_dir_all(&conf_dir).await?;
}
Command::new("mv")
.arg(&conf_dir_tmp)
.arg(&conf_dir)
.invoke(ErrorKind::Filesystem)
.await?;
}
tokio::fs::remove_file(&incomplete_path).await?;
}
if tokio::fs::metadata(&incomplete_path).await.is_ok() {
unreachable!() // paranoia
}
}
Command::new("systemctl")
.arg("start")
.arg(format!("postgresql@{pg_version}-main.service"))
@@ -209,7 +147,12 @@ pub struct Version;
impl VersionT for Version {
type Previous = v0_3_5_2::Version;
type PreUpRes = (AccountInfo, SshKeys, CifsTargets);
type PreUpRes = (
AccountInfo,
SshKeys,
CifsTargets,
BTreeMap<PackageId, BTreeMap<HostId, TorSecretKey>>,
);
fn semver(self) -> exver::Version {
V0_3_6_alpha_0.clone()
}
@@ -224,9 +167,15 @@ impl VersionT for Version {
let cifs = previous_cifs(&pg).await?;
Ok((account, ssh_keys, cifs))
let tor_keys = previous_tor_keys(&pg).await?;
Ok((account, ssh_keys, cifs, tor_keys))
}
fn up(self, db: &mut Value, (account, ssh_keys, cifs): Self::PreUpRes) -> Result<Value, Error> {
fn up(
self,
db: &mut Value,
(account, ssh_keys, cifs, tor_keys): Self::PreUpRes,
) -> Result<Value, Error> {
let prev_package_data = db["package-data"].clone();
let wifi = json!({
@@ -288,9 +237,15 @@ impl VersionT for Version {
"ui": db["ui"],
});
let mut keystore = KeyStore::new(&account)?;
for key in tor_keys.values().flat_map(|v| v.values()) {
assert!(key.is_valid());
keystore.onion.insert(key.clone());
}
let private = {
let mut value = json!({});
value["keyStore"] = to_value(&KeyStore::new(&account)?)?;
value["keyStore"] = crate::dbg!(to_value(&keystore)?);
value["password"] = to_value(&account.password)?;
value["compatS9pkKey"] =
to_value(&crate::db::model::private::generate_developer_key())?;
@@ -373,6 +328,20 @@ impl VersionT for Version {
false
};
let onions = input[&*id]["installed"]["interface-addresses"]
.as_object()
.into_iter()
.flatten()
.filter_map(|(id, addrs)| {
addrs["tor-address"].as_str().map(|addr| {
Ok((
HostId::from(Id::try_from(id.clone())?),
addr.parse::<OnionAddress>()?,
))
})
})
.collect::<Result<BTreeMap<_, _>, Error>>()?;
if let Err(e) = async {
let package_s9pk = tokio::fs::File::open(path).await?;
let file = MultiCursorFile::open(&package_s9pk).await?;
@@ -390,19 +359,44 @@ impl VersionT for Version {
.await?
.await?;
if configured {
ctx.db
.mutate(|db| {
db.as_public_mut()
.as_package_data_mut()
.as_idx_mut(&id)
.or_not_found(&id)?
let to_sync = ctx
.db
.mutate(|db| {
let mut to_sync = BTreeSet::new();
let package = db
.as_public_mut()
.as_package_data_mut()
.as_idx_mut(&id)
.or_not_found(&id)?;
if configured {
package
.as_tasks_mut()
.remove(&ReplayId::from("needs-config"))
})
.await
.result?;
.remove(&ReplayId::from("needs-config"))?;
}
for (id, onion) in onions {
package
.as_hosts_mut()
.upsert(&id, || Ok(Host::new()))?
.as_onions_mut()
.mutate(|o| {
o.clear();
o.insert(onion);
Ok(())
})?;
to_sync.insert(id);
}
Ok(to_sync)
})
.await
.result?;
if let Some(service) = &*ctx.services.get(&id).await {
for host_id in to_sync {
service.sync_host(host_id.clone()).await?;
}
}
Ok::<_, Error>(())
}
.await
@@ -470,14 +464,12 @@ async fn previous_account_info(pg: &sqlx::Pool<sqlx::Postgres>) -> Result<Accoun
.try_get::<Option<Vec<u8>>, _>("tor_key")
.with_ctx(|_| (ErrorKind::Database, "tor_key"))?
{
<[u8; 64]>::try_from(bytes)
.map_err(|e| {
Error::new(
eyre!("expected vec of len 64, got len {}", e.len()),
ErrorKind::ParseDbField,
)
})
.with_ctx(|_| (ErrorKind::Database, "password.u8 64"))?
<[u8; 64]>::try_from(bytes).map_err(|e| {
Error::new(
eyre!("expected vec of len 64, got len {}", e.len()),
ErrorKind::ParseDbField,
)
})?
} else {
ed25519_expand_key(
&<[u8; 32]>::try_from(
@@ -490,8 +482,7 @@ async fn previous_account_info(pg: &sqlx::Pool<sqlx::Postgres>) -> Result<Accoun
eyre!("expected vec of len 32, got len {}", e.len()),
ErrorKind::ParseDbField,
)
})
.with_ctx(|_| (ErrorKind::Database, "password.u8 32"))?,
})?,
)
},
)?],
@@ -565,3 +556,69 @@ async fn previous_ssh_keys(pg: &sqlx::Pool<sqlx::Postgres>) -> Result<SshKeys, E
};
Ok(ssh_keys)
}
/// Loads the Tor hidden-service keys from the pre-0.3.6 Postgres database,
/// grouped by package and host (interface) id.
///
/// Reads two legacy tables: `network_keys` (32-byte ed25519 seeds that must
/// be expanded to the 64-byte expanded form) and `tor` (already-expanded
/// 64-byte keys). Entries from `tor` overwrite `network_keys` entries for
/// the same (package, interface) pair because they are inserted second.
#[tracing::instrument(skip_all)]
async fn previous_tor_keys(
pg: &sqlx::Pool<sqlx::Postgres>,
) -> Result<BTreeMap<PackageId, BTreeMap<HostId, TorSecretKey>>, Error> {
let mut res = BTreeMap::<PackageId, BTreeMap<HostId, TorSecretKey>>::new();
// Pass 1: seeds from `network_keys` (32 bytes, expanded here).
let net_key_query = sqlx::query(r#"SELECT * FROM network_keys"#)
.fetch_all(pg)
.await
.with_kind(ErrorKind::Database)?;
for row in net_key_query {
let package_id: PackageId = row
.try_get::<String, _>("package")
.with_ctx(|_| (ErrorKind::Database, "network_keys::package"))?
.parse()?;
let interface_id: HostId = row
.try_get::<String, _>("interface")
.with_ctx(|_| (ErrorKind::Database, "network_keys::interface"))?
.parse()?;
let key = TorSecretKey::from_bytes(ed25519_expand_key(
&<[u8; 32]>::try_from(
row.try_get::<Vec<u8>, _>("key")
.with_ctx(|_| (ErrorKind::Database, "network_keys::key"))?,
)
.map_err(|e| {
Error::new(
eyre!("expected vec of len 32, got len {}", e.len()),
ErrorKind::ParseDbField,
)
})?,
))?;
res.entry(package_id).or_default().insert(interface_id, key);
}
// Pass 2: already-expanded keys from `tor` (64 bytes, used as-is).
let tor_key_query = sqlx::query(r#"SELECT * FROM tor"#)
.fetch_all(pg)
.await
.with_kind(ErrorKind::Database)?;
for row in tor_key_query {
let package_id: PackageId = row
.try_get::<String, _>("package")
.with_ctx(|_| (ErrorKind::Database, "tor::package"))?
.parse()?;
let interface_id: HostId = row
.try_get::<String, _>("interface")
.with_ctx(|_| (ErrorKind::Database, "tor::interface"))?
.parse()?;
let key = TorSecretKey::from_bytes(
<[u8; 64]>::try_from(
row.try_get::<Vec<u8>, _>("key")
.with_ctx(|_| (ErrorKind::Database, "tor::key"))?,
)
.map_err(|e| {
Error::new(
eyre!("expected vec of len 64, got len {}", e.len()),
ErrorKind::ParseDbField,
)
})?,
)?;
res.entry(package_id).or_default().insert(interface_id, key);
}
Ok(res)
}

View File

@@ -6,7 +6,7 @@ use models::GatewayId;
use serde::{Deserialize, Serialize};
use super::v0_3_5::V0_3_0_COMPAT;
use super::{v0_3_6_alpha_9, VersionT};
use super::{VersionT, v0_3_6_alpha_9};
use crate::net::host::address::PublicDomainConfig;
use crate::net::tor::OnionAddress;
use crate::prelude::*;

View File

@@ -50,10 +50,7 @@ impl VersionT for Version {
async fn post_up(self, ctx: &RpcContext, _input: Value) -> Result<(), Error> {
Command::new("systemd-firstboot")
.arg("--root=/media/startos/config/overlay/")
.arg(format!(
"--hostname={}",
ctx.account.read().await.hostname.0
))
.arg(ctx.account.peek(|a| format!("--hostname={}", a.hostname.0)))
.invoke(ErrorKind::ParseSysInfo)
.await?;
Ok(())

View File

@@ -115,7 +115,7 @@ impl VersionT for Version {
let manifest: Manifest = from_value(manifest.clone())?;
let id = manifest.id.clone();
let mut s9pk: S9pk<_> = S9pk::new_with_manifest(archive, None, manifest);
let s9pk_compat_key = ctx.account.read().await.developer_key.clone();
let s9pk_compat_key = ctx.account.peek(|a| a.developer_key.clone());
s9pk.as_archive_mut()
.set_signer(s9pk_compat_key, SIG_CONTEXT);
s9pk.serialize(&mut tmp_file, true).await?;

View File

@@ -31,7 +31,7 @@ impl VersionT for Version {
fn compat(self) -> &'static VersionRange {
&V0_3_0_COMPAT
}
#[instrument]
#[instrument(skip_all)]
fn up(self, db: &mut Value, _: Self::PreUpRes) -> Result<Value, Error> {
let default_gateway = db["public"]["serverInfo"]["network"]["networkInterfaces"]
.as_object()

View File

@@ -1,7 +1,7 @@
use exver::{PreReleaseSegment, VersionRange};
use super::v0_3_5::V0_3_0_COMPAT;
use super::{v0_4_0_alpha_10, VersionT};
use super::{VersionT, v0_4_0_alpha_10};
use crate::prelude::*;
lazy_static::lazy_static! {
@@ -27,7 +27,7 @@ impl VersionT for Version {
fn compat(self) -> &'static VersionRange {
&V0_3_0_COMPAT
}
#[instrument]
#[instrument(skip_all)]
fn up(self, db: &mut Value, _: Self::PreUpRes) -> Result<Value, Error> {
Ok(Value::Null)
}

View File

@@ -0,0 +1,85 @@
use std::collections::BTreeSet;
use exver::{PreReleaseSegment, VersionRange};
use imbl_value::InternedString;
use super::v0_3_5::V0_3_0_COMPAT;
use super::{VersionT, v0_4_0_alpha_11};
use crate::net::tor::TorSecretKey;
use crate::prelude::*;
lazy_static::lazy_static! {
static ref V0_4_0_alpha_12: exver::Version = exver::Version::new(
[0, 4, 0],
[PreReleaseSegment::String("alpha".into()), 12.into()]
);
}
#[derive(Clone, Copy, Debug, Default)]
pub struct Version;
/// Migration to 0.4.0-alpha.12: prunes invalid/mismatched Tor keys from the
/// key store, removes onion addresses that no longer have a backing key from
/// every host (package hosts and the server host), and renames the key-store
/// field `local_certs` to `localCerts`.
impl VersionT for Version {
type Previous = v0_4_0_alpha_11::Version;
type PreUpRes = ();
async fn pre_up(self) -> Result<Self::PreUpRes, Error> {
Ok(())
}
fn semver(self) -> exver::Version {
V0_4_0_alpha_12.clone()
}
fn compat(self) -> &'static VersionRange {
&V0_3_0_COMPAT
}
#[instrument(skip_all)]
fn up(self, db: &mut Value, _: Self::PreUpRes) -> Result<Value, Error> {
// Drop key-store entries whose key is invalid or whose stored onion
// address doesn't match the key. A deserialization failure is recorded
// (entry kept) and turned into a hard error after the retain pass.
let mut err = None;
let onion_store = db["private"]["keyStore"]["onion"]
.as_object_mut()
.or_not_found("private.keyStore.onion")?;
onion_store.retain(|o, v| match from_value::<TorSecretKey>(v.clone()) {
Ok(k) => k.is_valid() && &InternedString::from_display(&k.onion_address()) == o,
Err(e) => {
err = Some(e);
true
}
});
if let Some(e) = err {
return Err(e);
}
// Only onion addresses that still have a key in the store may remain on hosts.
let allowed_addresses = onion_store.keys().cloned().collect::<BTreeSet<_>>();
let fix_host = |host: &mut Value| {
Ok::<_, Error>(
host["onions"]
.as_array_mut()
.or_not_found("host.onions")?
.retain(|addr| {
addr.as_str()
.map(|s| allowed_addresses.contains(s))
.unwrap_or(false)
}),
)
};
// Scrub the onion list of every host of every installed package...
for (_, pde) in db["public"]["packageData"]
.as_object_mut()
.or_not_found("public.packageData")?
.iter_mut()
{
for (_, host) in pde["hosts"]
.as_object_mut()
.or_not_found("public.packageData[].hosts")?
.iter_mut()
{
fix_host(host)?;
}
}
// ...and of the server's own host.
fix_host(&mut db["public"]["serverInfo"]["network"]["host"])?;
// Rename snake_case key to camelCase.
// NOTE(review): the old `local_certs` key is not removed here — confirm
// leaving the stale key in place is intentional.
db["private"]["keyStore"]["localCerts"] = db["private"]["keyStore"]["local_certs"].clone();
Ok(Value::Null)
}
fn down(self, _db: &mut Value) -> Result<(), Error> {
Ok(())
}
}

View File

@@ -29,7 +29,7 @@ impl VersionT for Version {
fn compat(self) -> &'static VersionRange {
&V0_3_0_COMPAT
}
#[instrument]
#[instrument(skip_all)]
fn up(self, db: &mut Value, _: Self::PreUpRes) -> Result<Value, Error> {
db["public"]["serverInfo"]
.as_object_mut()

View File

@@ -27,7 +27,7 @@ impl VersionT for Version {
fn compat(self) -> &'static VersionRange {
&V0_3_0_COMPAT
}
#[instrument]
#[instrument(skip_all)]
fn up(self, _db: &mut Value, _: Self::PreUpRes) -> Result<Value, Error> {
Ok(Value::Null)
}

View File

@@ -27,7 +27,7 @@ impl VersionT for Version {
fn compat(self) -> &'static VersionRange {
&V0_3_0_COMPAT
}
#[instrument]
#[instrument(skip_all)]
fn up(self, db: &mut Value, _: Self::PreUpRes) -> Result<Value, Error> {
let ui = db["public"]["ui"]
.as_object_mut()

View File

@@ -27,7 +27,7 @@ impl VersionT for Version {
fn compat(self) -> &'static VersionRange {
&V0_3_0_COMPAT
}
#[instrument]
#[instrument(skip_all)]
fn up(self, _db: &mut Value, _: Self::PreUpRes) -> Result<Value, Error> {
Ok(Value::Null)
}

View File

@@ -27,7 +27,7 @@ impl VersionT for Version {
fn compat(self) -> &'static VersionRange {
&V0_3_0_COMPAT
}
#[instrument]
#[instrument(skip_all)]
fn up(self, _db: &mut Value, _: Self::PreUpRes) -> Result<Value, Error> {
Ok(Value::Null)
}

Some files were not shown because too many files have changed in this diff Show More