From e0c9f8a5aad233755cd2341015adb8e9cd4a3ae1 Mon Sep 17 00:00:00 2001 From: Aiden McClelland <3732071+dr-bonez@users.noreply.github.com> Date: Thu, 7 Mar 2024 14:40:22 -0700 Subject: [PATCH] Feature/remove postgres (#2570) * wip: move postgres data to patchdb * wip * wip * wip * complete notifications and clean up warnings * fill in user agent * move os tor bindings to single call --- core/Cargo.lock | 98 ++--- core/models/Cargo.toml | 2 +- core/models/src/errors.rs | 5 + core/models/src/id/host.rs | 11 + core/models/src/id/mod.rs | 5 + core/models/src/id/package.rs | 11 + core/models/src/procedure_name.rs | 2 +- core/startos/Cargo.toml | 16 +- core/startos/src/account.rs | 125 +++--- core/startos/src/auth.rs | 134 +++--- core/startos/src/backup/backup_bulk.rs | 288 +++++++----- core/startos/src/backup/os.rs | 131 ++++-- core/startos/src/backup/restore.rs | 15 +- core/startos/src/backup/target/cifs.rs | 178 ++++---- core/startos/src/backup/target/mod.rs | 32 +- core/startos/src/context/config.rs | 10 +- core/startos/src/context/rpc.rs | 22 +- core/startos/src/context/setup.rs | 9 +- core/startos/src/control.rs | 6 +- core/startos/src/db/model.rs | 104 ++++- core/startos/src/db/prelude.rs | 237 +++++++--- core/startos/src/dependencies.rs | 10 +- core/startos/src/firmware.rs | 1 - core/startos/src/init.rs | 23 +- core/startos/src/install/mod.rs | 4 - core/startos/src/logs.rs | 2 +- core/startos/src/lxc/mod.rs | 17 + core/startos/src/middleware/auth.rs | 84 ++-- core/startos/src/net/dns.rs | 5 +- core/startos/src/net/forward.rs | 177 ++++++++ core/startos/src/net/host/address.rs | 9 + core/startos/src/net/host/binding.rs | 71 +++ core/startos/src/net/host/mod.rs | 93 +++- core/startos/src/net/host/multi.rs | 13 - core/startos/src/net/keys.rs | 399 +---------------- core/startos/src/net/mdns.rs | 8 +- core/startos/src/net/mod.rs | 14 +- core/startos/src/net/net_controller.rs | 413 ++++++++++-------- core/startos/src/net/ssl.rs | 288 ++++++------ core/startos/src/net/static_server.rs | 6 +- core/startos/src/net/tor.rs | 162 ++++--- core/startos/src/net/vhost.rs | 73 ++-- core/startos/src/notifications.rs | 340 ++++++-------- core/startos/src/registry/admin.rs | 4 +- core/startos/src/s9pk/merkle_archive/mod.rs | 2 - .../source/multi_cursor_file.rs | 4 +- core/startos/src/service/config.rs | 4 +- core/startos/src/service/mod.rs | 15 +- .../src/service/persistent_container.rs | 17 +- core/startos/src/service/rpc.rs | 2 +- .../src/service/service_effect_handler.rs | 12 +- core/startos/src/service/service_map.rs | 26 +- core/startos/src/service/transition/mod.rs | 4 +- core/startos/src/setup.rs | 42 +- core/startos/src/ssh.rs | 167 ++++--- core/startos/src/status/mod.rs | 6 + core/startos/src/update/mod.rs | 51 +-- core/startos/src/upload.rs | 8 +- core/startos/src/util/future.rs | 4 +- core/startos/src/util/serde.rs | 150 +++++++ core/startos/src/version/mod.rs | 214 ++++----- core/startos/src/version/v0_3_4.rs | 140 ------ core/startos/src/version/v0_3_4_1.rs | 31 -- core/startos/src/version/v0_3_4_2.rs | 31 -- core/startos/src/version/v0_3_4_3.rs | 31 -- core/startos/src/version/v0_3_4_4.rs | 43 -- core/startos/src/version/v0_3_5.rs | 103 +---- core/startos/src/version/v0_3_5_1.rs | 9 +- core/startos/src/version/v0_3_6.rs | 29 ++ core/startos/src/volume.rs | 10 + 70 files changed, 2429 insertions(+), 2383 deletions(-) create mode 100644 core/startos/src/net/forward.rs create mode 100644 core/startos/src/net/host/address.rs create mode 100644 core/startos/src/net/host/binding.rs delete mode 100644 core/startos/src/net/host/multi.rs delete mode 100644 core/startos/src/version/v0_3_4.rs delete mode 100644 core/startos/src/version/v0_3_4_1.rs delete mode 100644 core/startos/src/version/v0_3_4_2.rs delete mode 100644 core/startos/src/version/v0_3_4_3.rs delete mode 100644 core/startos/src/version/v0_3_4_4.rs create mode 100644 core/startos/src/version/v0_3_6.rs diff --git a/core/Cargo.lock b/core/Cargo.lock index 95b7f3ca9..5435fe2b0 100644 --- a/core/Cargo.lock +++ b/core/Cargo.lock @@ -304,7 +304,7 @@ checksum = "1236b4b292f6c4d6dc34604bb5120d85c3fe1d1aa596bd5cc52ca054d13e7b9e" dependencies = [ "async-trait", "axum-core 0.4.3", - "base64 0.21.7", + "base64", "bytes", "futures-util", "http 1.0.0", @@ -417,12 +417,6 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23ce669cd6c8588f79e15cf450314f9638f967fc5770ff1c7c1deb0925ea7cfa" -[[package]] -name = "base64" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" - [[package]] name = "base64" version = "0.21.7" @@ -533,7 +527,6 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" dependencies = [ - "block-padding", "generic-array", ] @@ -546,12 +539,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "block-padding" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d696c370c750c948ada61c69a0ee2cbbb9c50b1019ddb86d9317157a99c2cae" - [[package]] name = "brotli" version = "3.4.0" @@ -1044,16 +1031,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "crypto-mac" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25fab6889090c8133f3deb8f73ba3c65a7f456f66436fc012a1b1e272b1e103e" -dependencies = [ - "generic-array", - "subtle", -] - [[package]] name = "csv" version = "1.3.0" @@ -1847,7 +1824,7 @@ version = "7.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d" dependencies = [ - "base64 0.21.7", + "base64", "byteorder", "flate2", "nom", @@ -1913,17 +1890,7 @@ version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" dependencies = [ - "hmac 0.12.1", -] - -[[package]] -name = "hmac" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a2a2320eb7ec0ebe8da8f744d7812d9fc4cb4d09344ac01898dbcb6a20ae69b" -dependencies = [ - "crypto-mac", - "digest 0.9.0", + "hmac", ] [[package]] @@ -2125,6 +2092,15 @@ dependencies = [ "cc", ] +[[package]] +name = "id-pool" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d0df4d8a768821ee4aa2e0353f67125c4586f0e13adbf95b8ebbf8d8fdb344" +dependencies = [ + "serde", +] + [[package]] name = "ident_case" version = "1.0.1" @@ -2412,7 +2388,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cd20997283339a19226445db97d632c8dc7adb6b8172537fe0e9e540fb141df2" dependencies = [ "anyhow", - "base64 0.21.7", + "base64", "flate2", "once_cell", "openssl", @@ -2684,7 +2660,7 @@ dependencies = [ name = "models" version = "0.1.0" dependencies = [ - "base64 0.21.7", + "base64", "color-eyre", "ed25519-dalek 2.1.1", "emver", @@ -2962,7 +2938,7 @@ version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c75a0ec2d1b302412fb503224289325fcc0e44600176864804c7211b055cfd58" dependencies = [ - "base64 0.21.7", + "base64", "byteorder", "md-5", "sha2 0.10.8", @@ -3156,7 +3132,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2" dependencies = [ "digest 0.10.7", - "hmac 0.12.1", + "hmac", ] [[package]] @@ -3609,7 +3585,7 @@ version = "0.11.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251" dependencies = [ - "base64 0.21.7", + "base64", "bytes", "cookie 0.17.0", "cookie_store", @@ -3666,7 +3642,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" dependencies = [ - "hmac 0.12.1", + "hmac", "subtle", ] @@ -3827,7 +3803,7 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d9848531d60c9cbbcf9d166c885316c24bc0e2a9d3eba0956bb6cbbd79bc6e8" dependencies = [ - "base64 0.21.7", + "base64", "blake2b_simd", "constant_time_eq", ] @@ -3891,7 +3867,7 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ - "base64 0.21.7", + "base64", ] [[package]] @@ -4128,7 +4104,7 @@ version = "3.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "15d167997bd841ec232f5b2b8e0e26606df2e7caa4c31b95ea9ca52b200bd270" dependencies = [ - "base64 0.21.7", + "base64", "chrono", "hex", "indexmap 1.9.3", @@ -4202,14 +4178,12 @@ dependencies = [ [[package]] name = "sha3" -version = "0.9.1" +version = "0.10.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f81199417d4e5de3f04b1e871023acea7389672c4135918f05aa9cbf2f2fa809" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" dependencies = [ - "block-buffer 0.9.0", - "digest 0.9.0", + "digest 0.10.7", "keccak", - "opaque-debug", ] [[package]] @@ -4455,7 +4429,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4" dependencies = [ "atoi", - "base64 0.21.7", + "base64", "bitflags 2.4.2", "byteorder", "bytes", @@ -4471,7 +4445,7 @@ dependencies = [ "generic-array", "hex", "hkdf", - "hmac 0.12.1", + "hmac", "itoa", "log", "md-5", @@ -4498,7 +4472,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24" dependencies = [ "atoi", - "base64 0.21.7", + "base64", "bitflags 2.4.2", "byteorder", "chrono", @@ -4511,7 +4485,7 @@ dependencies = [ "futures-util", "hex", "hkdf", - "hmac 0.12.1", + "hmac", "home", "itoa", "log", @@ -4635,7 +4609,7 @@ dependencies = [ "axum 0.7.4", "axum-server", "base32", - "base64 0.21.7", + "base64", "base64ct", "basic-cookies", "blake3", @@ -4660,8 +4634,9 @@ dependencies = [ "gpt", "helpers", "hex", - "hmac 0.12.1", + "hmac", "http 1.0.0", + "id-pool", "imbl", "imbl-value", "include_dir", @@ -5215,7 +5190,7 @@ dependencies = [ "async-stream", "async-trait", "axum 0.6.20", - "base64 0.21.7", + "base64", "bytes", "h2 0.3.24", "http 0.2.11", @@ -5236,19 +5211,18 @@ dependencies = [ [[package]] name = "torut" version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99febc413f26cf855b3a309c5872edff5c31e0ffe9c2fce5681868761df36f69" +source = "git+https://github.com/Start9Labs/torut.git?branch=update/dependencies#2b6fa9528d22e0b276132bccf7f2e9308f84b2c7" dependencies = [ "base32", - "base64 0.13.1", + "base64", "derive_more", "ed25519-dalek 1.0.1", "hex", - "hmac 0.11.0", + "hmac", "rand 0.7.3", "serde", "serde_derive", - "sha2 0.9.9", + "sha2 0.10.8", "sha3", "tokio", ] diff --git a/core/models/Cargo.toml b/core/models/Cargo.toml index c6fc76f55..250ba22a7 100644 --- a/core/models/Cargo.toml +++ b/core/models/Cargo.toml @@ -34,6 +34,6 @@ sqlx = { version = "0.7.2", features = [ ssh-key = "0.6.2" thiserror = "1.0" tokio = { version = "1", features = ["full"] } -torut = "0.2.1" +torut = { git = "https://github.com/Start9Labs/torut.git", branch = "update/dependencies" } tracing = "0.1.39" yasi = "0.1.5" diff --git a/core/models/src/errors.rs b/core/models/src/errors.rs index 15bc90b9f..2362b6dba 100644 --- a/core/models/src/errors.rs +++ b/core/models/src/errors.rs @@ -207,6 +207,11 @@ impl Error { } } } +impl From for Error { + fn from(value: std::convert::Infallible) -> Self { + match value {} + } +} impl From for Error { fn from(err: InvalidId) -> Self { Error::new(err, ErrorKind::InvalidPackageId) diff --git a/core/models/src/id/host.rs b/core/models/src/id/host.rs index 91abd56e7..6bca7d0ff 100644 --- a/core/models/src/id/host.rs +++ b/core/models/src/id/host.rs @@ -2,6 +2,7 @@ use std::path::Path; use std::str::FromStr; use serde::{Deserialize, Deserializer, Serialize}; +use yasi::InternedString; use crate::{Id, InvalidId}; @@ -18,6 +19,16 @@ impl From for HostId { Self(id) } } +impl From for Id { + fn from(value: HostId) -> Self { + value.0 + } +} +impl From for InternedString { + fn from(value: HostId) -> Self { + value.0.into() + } +} impl std::fmt::Display for HostId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", &self.0) diff --git a/core/models/src/id/mod.rs b/core/models/src/id/mod.rs index 068955336..efbe1f818 100644 --- a/core/models/src/id/mod.rs +++ b/core/models/src/id/mod.rs @@ -59,6 +59,11 @@ impl TryFrom<&str> for Id { } } } +impl From for InternedString { + fn from(value: Id) -> Self { + value.0 + } +} impl std::ops::Deref for Id { type Target = str; fn deref(&self) -> &Self::Target { diff --git a/core/models/src/id/package.rs b/core/models/src/id/package.rs index 14c29d88b..060f541c3 100644 --- a/core/models/src/id/package.rs +++ b/core/models/src/id/package.rs @@ -3,6 +3,7 @@ use std::path::Path; use std::str::FromStr; use serde::{Deserialize, Serialize, Serializer}; +use yasi::InternedString; use crate::{Id, InvalidId, SYSTEM_ID}; @@ -22,6 +23,16 @@ impl From for PackageId { PackageId(id) } } +impl From for Id { + fn from(value: PackageId) -> Self { + value.0 + } +} +impl From for InternedString { + fn from(value: PackageId) -> Self { + value.0.into() + } +} impl std::ops::Deref for PackageId { type Target = str; fn deref(&self) -> &Self::Target { diff --git a/core/models/src/procedure_name.rs b/core/models/src/procedure_name.rs index 841f8df7d..bf69b06b8 100644 --- a/core/models/src/procedure_name.rs +++ b/core/models/src/procedure_name.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; -use crate::{ActionId, HealthCheckId, PackageId}; +use crate::ActionId; #[derive(Debug, Clone, Serialize, Deserialize)] pub enum ProcedureName { diff --git a/core/startos/Cargo.toml b/core/startos/Cargo.toml index 3fce87089..fd36c3de6 100644 --- a/core/startos/Cargo.toml +++ b/core/startos/Cargo.toml @@ -87,14 +87,10 @@ helpers = { path = "../helpers" } hex = "0.4.3" hmac = "0.12.1" http = "1.0.0" -# http-body-util = "0.1.0" -# hyper = { version = "1.1.0", features = ["full"] } -# hyper-util = { version = "0.1.2", features = [ -# "server", -# "server-auto", -# "tokio", -# ] } -# hyper-ws-listener = "0.3.0" +id-pool = { version = "0.2.2", default-features = false, features = [ + "serde", + "u16", +] } imbl = "2.0.2" imbl-value = { git = "https://github.com/Start9Labs/imbl-value.git" } include_dir = "0.7.3" @@ -169,7 +165,9 @@ tokio-stream = { version = "0.1.14", features = ["io-util", "sync", "net"] } tokio-tar = { git = "https://github.com/dr-bonez/tokio-tar.git" } tokio-tungstenite = { version = "0.21.0", features = ["native-tls"] } tokio-util = { version = "0.7.9", features = ["io"] } -torut = "0.2.1" +torut = { git = "https://github.com/Start9Labs/torut.git", branch = "update/dependencies", features = [ + "serialize", +] } tracing = "0.1.39" tracing-error = "0.2.0" tracing-futures = "0.2.5" diff --git a/core/startos/src/account.rs b/core/startos/src/account.rs index cb08a0d53..e074d301d 100644 --- a/core/startos/src/account.rs +++ b/core/startos/src/account.rs @@ -1,15 +1,14 @@ use std::time::SystemTime; -use ed25519_dalek::SecretKey; use openssl::pkey::{PKey, Private}; use openssl::x509::X509; -use sqlx::PgExecutor; +use torut::onion::TorSecretKeyV3; +use crate::db::model::DatabaseModel; use crate::hostname::{generate_hostname, generate_id, Hostname}; -use crate::net::keys::Key; use crate::net::ssl::{generate_key, make_root_cert}; use crate::prelude::*; -use crate::util::crypto::ed25519_expand_key; +use crate::util::serde::Pem; fn hash_password(password: &str) -> Result { argon2::hash_encoded( @@ -25,103 +24,83 @@ pub struct AccountInfo { pub server_id: String, pub hostname: Hostname, pub password: String, - pub key: Key, + pub tor_key: TorSecretKeyV3, pub root_ca_key: PKey, pub root_ca_cert: X509, + pub ssh_key: ssh_key::PrivateKey, } impl AccountInfo { pub fn new(password: &str, start_time: SystemTime) -> Result { let server_id = generate_id(); let hostname = generate_hostname(); + let tor_key = TorSecretKeyV3::generate(); let root_ca_key = generate_key()?; let root_ca_cert = make_root_cert(&root_ca_key, &hostname, start_time)?; + let ssh_key = ssh_key::PrivateKey::from(ssh_key::private::Ed25519Keypair::random( + &mut rand::thread_rng(), + )); Ok(Self { server_id, hostname, password: hash_password(password)?, - key: Key::new(None), + tor_key, root_ca_key, root_ca_cert, + ssh_key, }) } - pub async fn load(secrets: impl PgExecutor<'_>) -> Result { - let r = sqlx::query!("SELECT * FROM account WHERE id = 0") - .fetch_one(secrets) - .await?; - - let server_id = r.server_id.unwrap_or_else(generate_id); - let hostname = r.hostname.map(Hostname).unwrap_or_else(generate_hostname); - let password = r.password; - let network_key = SecretKey::try_from(r.network_key).map_err(|e| { - Error::new( - eyre!("expected vec of len 32, got len {}", e.len()), - ErrorKind::ParseDbField, - ) - })?; - let tor_key = if let Some(k) = &r.tor_key { - <[u8; 64]>::try_from(&k[..]).map_err(|_| { - Error::new( - eyre!("expected vec of len 64, got len {}", k.len()), - ErrorKind::ParseDbField, - ) - })? - } else { - ed25519_expand_key(&network_key) - }; - let key = Key::from_pair(None, network_key, tor_key); - let root_ca_key = PKey::private_key_from_pem(r.root_ca_key_pem.as_bytes())?; - let root_ca_cert = X509::from_pem(r.root_ca_cert_pem.as_bytes())?; + pub fn load(db: &DatabaseModel) -> Result { + let server_id = db.as_public().as_server_info().as_id().de()?; + let hostname = Hostname(db.as_public().as_server_info().as_hostname().de()?); + let password = db.as_private().as_password().de()?; + let key_store = db.as_private().as_key_store(); + let tor_addr = db.as_public().as_server_info().as_onion_address().de()?; + let tor_key = key_store.as_onion().get_key(&tor_addr)?; + let cert_store = key_store.as_local_certs(); + let root_ca_key = cert_store.as_root_key().de()?.0; + let root_ca_cert = cert_store.as_root_cert().de()?.0; + let ssh_key = db.as_private().as_ssh_privkey().de()?.0; Ok(Self { server_id, hostname, password, - key, + tor_key, root_ca_key, root_ca_cert, + ssh_key, }) } - pub async fn save(&self, secrets: impl PgExecutor<'_>) -> Result<(), Error> { - let server_id = self.server_id.as_str(); - let hostname = self.hostname.0.as_str(); - let password = self.password.as_str(); - let network_key = self.key.as_bytes(); - let network_key = network_key.as_slice(); - let root_ca_key = String::from_utf8(self.root_ca_key.private_key_to_pem_pkcs8()?)?; - let root_ca_cert = String::from_utf8(self.root_ca_cert.to_pem()?)?; - - sqlx::query!( - r#" - INSERT INTO account ( - id, - server_id, - hostname, - password, - network_key, - root_ca_key_pem, - root_ca_cert_pem - ) VALUES ( - 0, $1, $2, $3, $4, $5, $6 - ) ON CONFLICT (id) DO UPDATE SET - server_id = EXCLUDED.server_id, - hostname = EXCLUDED.hostname, - password = EXCLUDED.password, - network_key = EXCLUDED.network_key, - root_ca_key_pem = EXCLUDED.root_ca_key_pem, - root_ca_cert_pem = EXCLUDED.root_ca_cert_pem - "#, - server_id, - hostname, - password, - network_key, - root_ca_key, - root_ca_cert, - ) - .execute(secrets) - .await?; - + pub fn save(&self, db: &mut DatabaseModel) -> Result<(), Error> { + let server_info = db.as_public_mut().as_server_info_mut(); + server_info.as_id_mut().ser(&self.server_id)?; + server_info.as_hostname_mut().ser(&self.hostname.0)?; + server_info + .as_lan_address_mut() + .ser(&self.hostname.lan_address().parse()?)?; + server_info + .as_pubkey_mut() + .ser(&self.ssh_key.public_key().to_openssh()?)?; + let onion_address = self.tor_key.public().get_onion_address(); + server_info.as_onion_address_mut().ser(&onion_address)?; + server_info + .as_tor_address_mut() + .ser(&format!("https://{onion_address}").parse()?)?; + db.as_private_mut().as_password_mut().ser(&self.password)?; + db.as_private_mut() + .as_ssh_privkey_mut() + .ser(Pem::new_ref(&self.ssh_key))?; + let key_store = db.as_private_mut().as_key_store_mut(); + key_store.as_onion_mut().insert_key(&self.tor_key)?; + let cert_store = key_store.as_local_certs_mut(); + cert_store + .as_root_key_mut() + .ser(Pem::new_ref(&self.root_ca_key))?; + cert_store + .as_root_cert_mut() + .ser(Pem::new_ref(&self.root_ca_cert))?; Ok(()) } diff --git a/core/startos/src/auth.rs b/core/startos/src/auth.rs index 891f390bc..21cfbd101 100644 --- a/core/startos/src/auth.rs +++ b/core/startos/src/auth.rs @@ -1,17 +1,17 @@ use std::collections::BTreeMap; use chrono::{DateTime, Utc}; -use clap::{ArgMatches, Parser}; +use clap::Parser; use color_eyre::eyre::eyre; use imbl_value::{json, InternedString}; use josekit::jwk::Jwk; use rpc_toolkit::yajrc::RpcError; use rpc_toolkit::{command, from_fn_async, AnyContext, CallRemote, HandlerExt, ParentHandler}; use serde::{Deserialize, Serialize}; -use sqlx::{Executor, Postgres}; use tracing::instrument; use crate::context::{CliContext, RpcContext}; +use crate::db::model::DatabaseModel; use crate::middleware::auth::{ AsLogoutSessionId, HasLoggedOutSessions, HashSessionToken, LoginRes, }; @@ -19,6 +19,25 @@ use crate::prelude::*; use crate::util::crypto::EncryptedWire; use crate::util::serde::{display_serializable, HandlerExtSerde, WithIoFormat}; use crate::{ensure_code, Error, ResultExt}; + +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct Sessions(pub BTreeMap); +impl Sessions { + pub fn new() -> Self { + Self(BTreeMap::new()) + } +} +impl Map for Sessions { + type Key = InternedString; + type Value = Session; + fn key_str(key: &Self::Key) -> Result, Error> { + Ok(key) + } + fn key_string(key: &Self::Key) -> Result { + Ok(key.clone()) + } +} + #[derive(Clone, Serialize, Deserialize)] #[serde(untagged)] pub enum PasswordType { @@ -95,16 +114,6 @@ pub fn auth() -> ParentHandler { ) } -pub fn cli_metadata() -> Value { - imbl_value::json!({ - "platforms": ["cli"], - }) -} - -pub fn parse_metadata(_: &str, _: &ArgMatches) -> Result { - Ok(cli_metadata()) -} - #[test] fn gen_pwd() { println!( @@ -163,14 +172,8 @@ pub fn check_password(hash: &str, password: &str) -> Result<(), Error> { Ok(()) } -pub async fn check_password_against_db(secrets: &mut Ex, password: &str) -> Result<(), Error> -where - for<'a> &'a mut Ex: Executor<'a, Database = Postgres>, -{ - let pw_hash = sqlx::query!("SELECT password FROM account") - .fetch_one(secrets) - .await? - .password; +pub fn check_password_against_db(db: &DatabaseModel, password: &str) -> Result<(), Error> { + let pw_hash = db.as_private().as_password().de()?; check_password(&pw_hash, password)?; Ok(()) } @@ -180,7 +183,8 @@ where #[command(rename_all = "kebab-case")] pub struct LoginParams { password: Option, - #[arg(skip = cli_metadata())] + #[serde(default)] + user_agent: Option, #[serde(default)] metadata: Value, } @@ -188,26 +192,31 @@ pub struct LoginParams { #[instrument(skip_all)] pub async fn login_impl( ctx: RpcContext, - LoginParams { password, metadata }: LoginParams, -) -> Result { - let password = password.unwrap_or_default().decrypt(&ctx)?; - let mut handle = ctx.secret_store.acquire().await?; - check_password_against_db(handle.as_mut(), &password).await?; - - let hash_token = HashSessionToken::new(); - let user_agent = "".to_string(); // todo!() as String; - let metadata = serde_json::to_string(&metadata).with_kind(crate::ErrorKind::Database)?; - let hash_token_hashed = hash_token.hashed(); - sqlx::query!( - "INSERT INTO session (id, user_agent, metadata) VALUES ($1, $2, $3)", - hash_token_hashed, + LoginParams { + password, user_agent, metadata, - ) - .execute(handle.as_mut()) - .await?; + }: LoginParams, +) -> Result { + let password = password.unwrap_or_default().decrypt(&ctx)?; - Ok(hash_token.to_login_res()) + ctx.db + .mutate(|db| { + check_password_against_db(db, &password)?; + let hash_token = HashSessionToken::new(); + db.as_private_mut().as_sessions_mut().insert( + hash_token.hashed(), + &Session { + logged_in: Utc::now(), + last_active: Utc::now(), + user_agent, + metadata, + }, + )?; + + Ok(hash_token.to_login_res()) + }) + .await } #[derive(Deserialize, Serialize, Parser)] @@ -226,20 +235,20 @@ pub async fn logout( )) } -#[derive(Deserialize, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] #[serde(rename_all = "kebab-case")] pub struct Session { - logged_in: DateTime, - last_active: DateTime, - user_agent: Option, - metadata: Value, + pub logged_in: DateTime, + pub last_active: DateTime, + pub user_agent: Option, + pub metadata: Value, } #[derive(Deserialize, Serialize)] #[serde(rename_all = "kebab-case")] pub struct SessionList { - current: String, - sessions: BTreeMap, + current: InternedString, + sessions: Sessions, } pub fn session() -> ParentHandler { @@ -277,7 +286,7 @@ fn display_sessions(params: WithIoFormat, arg: SessionList) { "USER AGENT", "METADATA", ]); - for (id, session) in arg.sessions { + for (id, session) in arg.sessions.0 { let mut row = row![ &id, &format!("{}", session.logged_in), @@ -310,33 +319,11 @@ pub async fn list( ListParams { session, .. }: ListParams, ) -> Result { Ok(SessionList { - current: HashSessionToken::from_token(session).hashed().to_owned(), - sessions: sqlx::query!( - "SELECT * FROM session WHERE logged_out IS NULL OR logged_out > CURRENT_TIMESTAMP" - ) - .fetch_all(ctx.secret_store.acquire().await?.as_mut()) - .await? - .into_iter() - .map(|row| { - Ok(( - row.id, - Session { - logged_in: DateTime::from_utc(row.logged_in, Utc), - last_active: DateTime::from_utc(row.last_active, Utc), - user_agent: row.user_agent, - metadata: serde_json::from_str(&row.metadata) - .with_kind(crate::ErrorKind::Database)?, - }, - )) - }) - .collect::>()?, + current: HashSessionToken::from_token(session).hashed().clone(), + sessions: ctx.db.peek().await.into_private().into_sessions().de()?, }) } -fn parse_comma_separated(arg: &str, _: &ArgMatches) -> Result, RpcError> { - Ok(arg.split(",").map(|s| s.trim().to_owned()).collect()) -} - #[derive(Debug, Clone, Serialize, Deserialize)] struct KillSessionId(InternedString); @@ -433,14 +420,17 @@ pub async fn reset_password_impl( )); } account.set_password(&new_password)?; - account.save(&ctx.secret_store).await?; let account_password = &account.password; + let account = account.clone(); ctx.db .mutate(|d| { d.as_public_mut() .as_server_info_mut() .as_password_hash_mut() - .ser(account_password) + .ser(account_password)?; + account.save(d)?; + + Ok(()) }) .await } diff --git a/core/startos/src/backup/backup_bulk.rs b/core/startos/src/backup/backup_bulk.rs index 4660ab4bc..4f633d7ae 100644 --- a/core/startos/src/backup/backup_bulk.rs +++ b/core/startos/src/backup/backup_bulk.rs @@ -1,5 +1,4 @@ use std::collections::BTreeMap; -use std::panic::UnwindSafe; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -19,11 +18,11 @@ use crate::auth::check_password_against_db; use crate::backup::os::OsBackup; use crate::backup::{BackupReport, ServerBackupReport}; use crate::context::RpcContext; -use crate::db::model::BackupProgress; +use crate::db::model::{BackupProgress, DatabaseModel}; use crate::disk::mount::backup::BackupMountGuard; use crate::disk::mount::filesystem::ReadWrite; use crate::disk::mount::guard::{GenericMountGuard, TmpMountGuard}; -use crate::notifications::NotificationLevel; +use crate::notifications::{notify, NotificationLevel}; use crate::prelude::*; use crate::util::io::dir_copy; use crate::util::serde::IoFormat; @@ -41,6 +40,111 @@ pub struct BackupParams { password: crate::auth::PasswordType, } +struct BackupStatusGuard(Option); +impl BackupStatusGuard { + fn new(db: PatchDb) -> Self { + Self(Some(db)) + } + async fn handle_result( + mut self, + result: Result, Error>, + ) -> Result<(), Error> { + if let Some(db) = self.0.as_ref() { + db.mutate(|v| { + v.as_public_mut() + .as_server_info_mut() + .as_status_info_mut() + .as_backup_progress_mut() + .ser(&None) + }) + .await?; + } + if let Some(db) = self.0.take() { + match result { + Ok(report) if report.iter().all(|(_, rep)| rep.error.is_none()) => { + db.mutate(|db| { + notify( + db, + None, + NotificationLevel::Success, + "Backup Complete".to_owned(), + "Your backup has completed".to_owned(), + BackupReport { + server: ServerBackupReport { + attempted: true, + error: None, + }, + packages: report, + }, + ) + }) + .await + } + Ok(report) => { + db.mutate(|db| { + notify( + db, + None, + NotificationLevel::Warning, + "Backup Complete".to_owned(), + "Your backup has completed, but some package(s) failed to backup" + .to_owned(), + BackupReport { + server: ServerBackupReport { + attempted: true, + error: None, + }, + packages: report, + }, + ) + }) + .await + } + Err(e) => { + tracing::error!("Backup Failed: {}", e); + tracing::debug!("{:?}", e); + let err_string = e.to_string(); + db.mutate(|db| { + notify( + db, + None, + NotificationLevel::Error, + "Backup Failed".to_owned(), + "Your backup failed to complete.".to_owned(), + BackupReport { + server: ServerBackupReport { + attempted: true, + error: Some(err_string), + }, + packages: BTreeMap::new(), + }, + ) + }) + .await + } + }?; + } + Ok(()) + } +} +impl Drop for BackupStatusGuard { + fn drop(&mut self) { + if let Some(db) = self.0.take() { + tokio::spawn(async move { + db.mutate(|v| { + v.as_public_mut() + .as_server_info_mut() + .as_status_info_mut() + .as_backup_progress_mut() + .ser(&None) + }) + .await + .unwrap() + }); + } + } +} + #[instrument(skip(ctx, old_password, password))] pub async fn backup_all( ctx: RpcContext, @@ -57,139 +161,81 @@ pub async fn backup_all( .clone() .decrypt(&ctx)?; let password = password.decrypt(&ctx)?; - check_password_against_db(ctx.secret_store.acquire().await?.as_mut(), &password).await?; - let fs = target_id - .load(ctx.secret_store.acquire().await?.as_mut()) - .await?; + + let ((fs, package_ids), status_guard) = ( + ctx.db + .mutate(|db| { + check_password_against_db(db, &password)?; + let fs = target_id.load(db)?; + let package_ids = if let Some(ids) = package_ids { + ids.into_iter().collect() + } else { + db.as_public() + .as_package_data() + .as_entries()? + .into_iter() + .filter(|(_, m)| m.expect_as_installed().is_ok()) + .map(|(id, _)| id) + .collect() + }; + assure_backing_up(db, &package_ids)?; + Ok((fs, package_ids)) + }) + .await?, + BackupStatusGuard::new(ctx.db.clone()), + ); + let mut backup_guard = BackupMountGuard::mount( TmpMountGuard::mount(&fs, ReadWrite).await?, &old_password_decrypted, ) .await?; - let package_ids = if let Some(ids) = package_ids { - ids.into_iter().collect() - } else { - todo!("all installed packages"); - }; if old_password.is_some() { backup_guard.change_password(&password)?; } - assure_backing_up(&ctx.db, &package_ids).await?; tokio::task::spawn(async move { - let backup_res = perform_backup(&ctx, backup_guard, &package_ids).await; - match backup_res { - Ok(report) if report.iter().all(|(_, rep)| rep.error.is_none()) => ctx - .notification_manager - .notify( - ctx.db.clone(), - None, - NotificationLevel::Success, - "Backup Complete".to_owned(), - "Your backup has completed".to_owned(), - BackupReport { - server: ServerBackupReport { - attempted: true, - error: None, - }, - packages: report, - }, - None, - ) - .await - .expect("failed to send notification"), - Ok(report) => ctx - .notification_manager - .notify( - ctx.db.clone(), - None, - NotificationLevel::Warning, - "Backup Complete".to_owned(), - "Your backup has completed, but some package(s) failed to backup".to_owned(), - BackupReport { - server: ServerBackupReport { - attempted: true, - error: None, - }, - packages: report, - }, - None, - ) - .await - .expect("failed to send notification"), - Err(e) => { - tracing::error!("Backup Failed: {}", e); - tracing::debug!("{:?}", e); - ctx.notification_manager - .notify( - ctx.db.clone(), - None, - NotificationLevel::Error, - "Backup Failed".to_owned(), - "Your backup failed to complete.".to_owned(), - BackupReport { - server: ServerBackupReport { - attempted: true, - error: Some(e.to_string()), - }, - packages: BTreeMap::new(), - }, - None, - ) - .await - .expect("failed to send notification"); - } - } - ctx.db - .mutate(|v| { - v.as_public_mut() - .as_server_info_mut() - .as_status_info_mut() - .as_backup_progress_mut() - .ser(&None) - }) - .await?; - Ok::<(), Error>(()) + status_guard + .handle_result(perform_backup(&ctx, backup_guard, &package_ids).await) + .await + .unwrap(); }); Ok(()) } #[instrument(skip(db, packages))] -async fn assure_backing_up( - db: &PatchDb, - packages: impl IntoIterator + UnwindSafe + Send, +fn assure_backing_up<'a>( + db: &mut DatabaseModel, + packages: impl IntoIterator, ) -> Result<(), Error> { - db.mutate(|v| { - let backing_up = v - .as_public_mut() - .as_server_info_mut() - .as_status_info_mut() - .as_backup_progress_mut(); - if backing_up - .clone() - .de()? - .iter() - .flat_map(|x| x.values()) - .fold(false, |acc, x| { - if !x.complete { - return true; - } - acc - }) - { - return Err(Error::new( - eyre!("Server is already backing up!"), - ErrorKind::InvalidRequest, - )); - } - backing_up.ser(&Some( - packages - .into_iter() - .map(|x| (x.clone(), BackupProgress { complete: false })) - .collect(), - ))?; - Ok(()) - }) - .await + let backing_up = db + .as_public_mut() + .as_server_info_mut() + .as_status_info_mut() + .as_backup_progress_mut(); + if backing_up + .clone() + .de()? + .iter() + .flat_map(|x| x.values()) + .fold(false, |acc, x| { + if !x.complete { + return true; + } + acc + }) + { + return Err(Error::new( + eyre!("Server is already backing up!"), + ErrorKind::InvalidRequest, + )); + } + backing_up.ser(&Some( + packages + .into_iter() + .map(|x| (x.clone(), BackupProgress { complete: false })) + .collect(), + ))?; + Ok(()) } #[instrument(skip(ctx, backup_guard))] diff --git a/core/startos/src/backup/os.rs b/core/startos/src/backup/os.rs index 5ab8bd12e..6848473a7 100644 --- a/core/startos/src/backup/os.rs +++ b/core/startos/src/backup/os.rs @@ -1,13 +1,15 @@ -use openssl::pkey::PKey; +use openssl::pkey::{PKey, Private}; use openssl::x509::X509; use patch_db::Value; use serde::{Deserialize, Serialize}; +use ssh_key::private::Ed25519Keypair; +use torut::onion::TorSecretKeyV3; use crate::account::AccountInfo; use crate::hostname::{generate_hostname, generate_id, Hostname}; -use crate::net::keys::Key; use crate::prelude::*; -use crate::util::serde::Base64; +use crate::util::crypto::ed25519_expand_key; +use crate::util::serde::{Base32, Base64, Pem}; pub struct OsBackup { pub account: AccountInfo, @@ -19,19 +21,23 @@ impl<'de> Deserialize<'de> for OsBackup { D: serde::Deserializer<'de>, { let tagged = OsBackupSerDe::deserialize(deserializer)?; - match tagged.version { + Ok(match tagged.version { 0 => patch_db::value::from_value::(tagged.rest) .map_err(serde::de::Error::custom)? .project() - .map_err(serde::de::Error::custom), + .map_err(serde::de::Error::custom)?, 1 => patch_db::value::from_value::(tagged.rest) .map_err(serde::de::Error::custom)? - .project() - .map_err(serde::de::Error::custom), - v => Err(serde::de::Error::custom(&format!( - "Unknown backup version {v}" - ))), - } + .project(), + 2 => patch_db::value::from_value::(tagged.rest) + .map_err(serde::de::Error::custom)? + .project(), + v => { + return Err(serde::de::Error::custom(&format!( + "Unknown backup version {v}" + ))) + } + }) } } impl Serialize for OsBackup { @@ -40,11 +46,9 @@ impl Serialize for OsBackup { S: serde::Serializer, { OsBackupSerDe { - version: 1, - rest: patch_db::value::to_value( - &OsBackupV1::unproject(self).map_err(serde::ser::Error::custom)?, - ) - .map_err(serde::ser::Error::custom)?, + version: 2, + rest: patch_db::value::to_value(&OsBackupV2::unproject(self)) + .map_err(serde::ser::Error::custom)?, } .serialize(serializer) } @@ -62,10 +66,10 @@ struct OsBackupSerDe { #[derive(Deserialize)] #[serde(rename = "kebab-case")] struct OsBackupV0 { - // tor_key: Base32<[u8; 64]>, - root_ca_key: String, // PEM Encoded OpenSSL Key - root_ca_cert: String, // PEM Encoded OpenSSL X509 Certificate - ui: Value, // JSON Value + tor_key: Base32<[u8; 64]>, // Base32 Encoded Ed25519 Expanded Secret Key + root_ca_key: Pem>, // PEM Encoded OpenSSL Key + root_ca_cert: Pem, // PEM Encoded OpenSSL X509 Certificate + ui: Value, // JSON Value } impl OsBackupV0 { fn project(self) -> Result { @@ -74,9 +78,13 @@ impl OsBackupV0 { server_id: generate_id(), hostname: generate_hostname(), password: Default::default(), - key: Key::new(None), - root_ca_key: PKey::private_key_from_pem(self.root_ca_key.as_bytes())?, - root_ca_cert: X509::from_pem(self.root_ca_cert.as_bytes())?, + root_ca_key: self.root_ca_key.0, + root_ca_cert: self.root_ca_cert.0, + ssh_key: ssh_key::PrivateKey::random( + &mut rand::thread_rng(), + ssh_key::Algorithm::Ed25519, + )?, + tor_key: TorSecretKeyV3::from(self.tor_key.0), }, ui: self.ui, }) @@ -87,36 +95,67 @@ impl OsBackupV0 { #[derive(Deserialize, Serialize)] #[serde(rename = "kebab-case")] struct OsBackupV1 { - server_id: String, // uuidv4 - hostname: String, // embassy-- - net_key: Base64<[u8; 32]>, // Ed25519 Secret Key - root_ca_key: String, // PEM Encoded OpenSSL Key - root_ca_cert: String, // PEM Encoded OpenSSL X509 Certificate - ui: Value, // JSON Value - // TODO add more + server_id: String, // uuidv4 + hostname: String, // embassy-- + net_key: Base64<[u8; 32]>, // Ed25519 Secret Key + root_ca_key: Pem>, // PEM Encoded OpenSSL Key + root_ca_cert: Pem, // PEM Encoded OpenSSL X509 Certificate + ui: Value, // JSON Value } impl OsBackupV1 { - fn project(self) -> Result { - Ok(OsBackup { + fn project(self) -> OsBackup { + OsBackup { account: AccountInfo { server_id: self.server_id, hostname: Hostname(self.hostname), password: Default::default(), - key: Key::from_bytes(None, self.net_key.0), - root_ca_key: PKey::private_key_from_pem(self.root_ca_key.as_bytes())?, - root_ca_cert: X509::from_pem(self.root_ca_cert.as_bytes())?, + root_ca_key: self.root_ca_key.0, + root_ca_cert: self.root_ca_cert.0, + ssh_key: ssh_key::PrivateKey::from(Ed25519Keypair::from_seed(&self.net_key.0)), + tor_key: TorSecretKeyV3::from(ed25519_expand_key(&self.net_key.0)), }, ui: self.ui, - }) - } - fn unproject(backup: &OsBackup) -> Result { - Ok(Self { - server_id: backup.account.server_id.clone(), - hostname: backup.account.hostname.0.clone(), - net_key: Base64(backup.account.key.as_bytes()), - root_ca_key: String::from_utf8(backup.account.root_ca_key.private_key_to_pem_pkcs8()?)?, - root_ca_cert: String::from_utf8(backup.account.root_ca_cert.to_pem()?)?, - ui: backup.ui.clone(), - }) + } + } +} + +/// V2 +#[derive(Deserialize, Serialize)] +#[serde(rename = "kebab-case")] + +struct OsBackupV2 { + server_id: String, // uuidv4 + hostname: String, // - + root_ca_key: Pem>, // PEM Encoded OpenSSL Key + root_ca_cert: Pem, // PEM Encoded OpenSSL X509 Certificate + ssh_key: Pem, // PEM Encoded OpenSSH Key + tor_key: TorSecretKeyV3, // Base64 Encoded Ed25519 Expanded Secret Key + ui: Value, // JSON Value +} +impl OsBackupV2 { + fn project(self) -> OsBackup { + OsBackup { + account: AccountInfo { + server_id: self.server_id, + hostname: Hostname(self.hostname), + password: Default::default(), + root_ca_key: self.root_ca_key.0, + root_ca_cert: self.root_ca_cert.0, + ssh_key: self.ssh_key.0, + tor_key: self.tor_key, + }, + ui: self.ui, + } + } + fn unproject(backup: &OsBackup) -> Self { + Self { + server_id: backup.account.server_id.clone(), + hostname: backup.account.hostname.0.clone(), + root_ca_key: Pem(backup.account.root_ca_key.clone()), + root_ca_cert: Pem(backup.account.root_ca_cert.clone()), + ssh_key: Pem(backup.account.ssh_key.clone()), + tor_key: backup.account.tor_key.clone(), + ui: backup.ui.clone(), + } } } diff --git a/core/startos/src/backup/restore.rs b/core/startos/src/backup/restore.rs index 404c12c6b..bae7eb58a 100644 --- a/core/startos/src/backup/restore.rs +++ b/core/startos/src/backup/restore.rs @@ -5,6 +5,7 @@ use clap::Parser; use futures::{stream, StreamExt}; use models::PackageId; use openssl::x509::X509; +use patch_db::json_ptr::ROOT; use serde::{Deserialize, Serialize}; use torut::onion::OnionAddressV3; use tracing::instrument; @@ -12,6 +13,7 @@ use tracing::instrument; use super::target::BackupTargetId; use crate::backup::os::OsBackup; use crate::context::{RpcContext, SetupContext}; +use crate::db::model::Database; use crate::disk::mount::backup::BackupMountGuard; use crate::disk::mount::filesystem::ReadWrite; use crate::disk::mount::guard::{GenericMountGuard, TmpMountGuard}; @@ -42,9 +44,7 @@ pub async fn restore_packages_rpc( password, }: RestorePackageParams, ) -> Result<(), Error> { - let fs = target_id - .load(ctx.secret_store.acquire().await?.as_mut()) - .await?; + let fs = target_id.load(&ctx.db.peek().await)?; let backup_guard = BackupMountGuard::mount(TmpMountGuard::mount(&fs, ReadWrite).await?, &password).await?; @@ -95,11 +95,8 @@ pub async fn recover_full_embassy( ) .with_kind(ErrorKind::PasswordHashGeneration)?; - let secret_store = ctx.secret_store().await?; - - os_backup.account.save(&secret_store).await?; - - secret_store.close().await; + let db = ctx.db().await?; + db.put(&ROOT, &Database::init(&os_backup.account)?).await?; init(&ctx.config).await?; @@ -129,7 +126,7 @@ pub async fn recover_full_embassy( Ok(( disk_guid, os_backup.account.hostname, - os_backup.account.key.tor_address(), + os_backup.account.tor_key.public().get_onion_address(), os_backup.account.root_ca_cert, )) } diff --git a/core/startos/src/backup/target/cifs.rs b/core/startos/src/backup/target/cifs.rs index 4f3ee4827..db332e28f 100644 --- a/core/startos/src/backup/target/cifs.rs +++ b/core/startos/src/backup/target/cifs.rs @@ -1,14 +1,15 @@ +use std::collections::BTreeMap; use std::path::{Path, PathBuf}; use clap::Parser; use color_eyre::eyre::eyre; -use futures::TryStreamExt; +use imbl_value::InternedString; use rpc_toolkit::{command, from_fn_async, HandlerExt, ParentHandler}; use serde::{Deserialize, Serialize}; -use sqlx::{Executor, Postgres}; use super::{BackupTarget, BackupTargetId}; use crate::context::{CliContext, RpcContext}; +use crate::db::model::DatabaseModel; use crate::disk::mount::filesystem::cifs::Cifs; use crate::disk::mount::filesystem::ReadOnly; use crate::disk::mount::guard::{GenericMountGuard, TmpMountGuard}; @@ -16,6 +17,24 @@ use crate::disk::util::{recovery_info, EmbassyOsRecoveryInfo}; use crate::prelude::*; use crate::util::serde::KeyVal; +#[derive(Debug, Default, Deserialize, Serialize)] +pub struct CifsTargets(pub BTreeMap); +impl CifsTargets { + pub fn new() -> Self { + Self(BTreeMap::new()) + } +} +impl Map for CifsTargets { + type Key = u32; + type Value = Cifs; + fn key_str(key: &Self::Key) -> Result, Error> { + Self::key_string(key) + } + fn key_string(key: &Self::Key) -> Result { + Ok(InternedString::from_display(key)) + } +} + #[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "kebab-case")] pub struct CifsBackupTarget { @@ -69,23 +88,27 @@ pub async fn add( ) -> Result, Error> { let cifs = Cifs { hostname, - path, + path: Path::new("/").join(path), username, password, }; let guard = TmpMountGuard::mount(&cifs, ReadOnly).await?; let embassy_os = recovery_info(guard.path()).await?; guard.unmount().await?; - let path_string = Path::new("/").join(&cifs.path).display().to_string(); - let id: i32 = sqlx::query!( - "INSERT INTO cifs_shares (hostname, path, username, password) VALUES ($1, $2, $3, $4) RETURNING id", - cifs.hostname, - path_string, - cifs.username, - cifs.password, - ) - .fetch_one(&ctx.secret_store) - .await?.id; + let id = ctx + .db + .mutate(|db| { + let id = db + .as_private() + .as_cifs() + .keys()? + .into_iter() + .max() + .map_or(0, |a| a + 1); + db.as_private_mut().as_cifs_mut().insert(&id, &cifs)?; + Ok(id) + }) + .await?; Ok(KeyVal { key: BackupTargetId::Cifs { id }, value: BackupTarget::Cifs(CifsBackupTarget { @@ -129,32 +152,27 @@ pub async fn update( }; let cifs = Cifs { hostname, - path, + path: Path::new("/").join(path), username, password, }; let guard = TmpMountGuard::mount(&cifs, ReadOnly).await?; let embassy_os = recovery_info(guard.path()).await?; guard.unmount().await?; - let path_string = Path::new("/").join(&cifs.path).display().to_string(); - if sqlx::query!( - "UPDATE cifs_shares SET hostname = $1, path = $2, username = $3, password = $4 WHERE id = $5", - cifs.hostname, - path_string, - cifs.username, - cifs.password, - id, - ) - .execute(&ctx.secret_store) - .await? - .rows_affected() - == 0 - { - return Err(Error::new( - eyre!("Backup Target ID {} Not Found", BackupTargetId::Cifs { id }), - ErrorKind::NotFound, - )); - }; + ctx.db + .mutate(|db| { + db.as_private_mut() + .as_cifs_mut() + .as_idx_mut(&id) + .ok_or_else(|| { + Error::new( + eyre!("Backup Target ID {} Not Found", BackupTargetId::Cifs { id }), + ErrorKind::NotFound, + ) + })? + .ser(&cifs) + }) + .await?; Ok(KeyVal { key: BackupTargetId::Cifs { id }, value: BackupTarget::Cifs(CifsBackupTarget { @@ -183,74 +201,46 @@ pub async fn remove(ctx: RpcContext, RemoveParams { id }: RemoveParams) -> Resul ErrorKind::NotFound, )); }; - if sqlx::query!("DELETE FROM cifs_shares WHERE id = $1", id) - .execute(&ctx.secret_store) - .await? - .rows_affected() - == 0 - { - return Err(Error::new( - eyre!("Backup Target ID {} Not Found", BackupTargetId::Cifs { id }), - ErrorKind::NotFound, - )); - }; + ctx.db + .mutate(|db| db.as_private_mut().as_cifs_mut().remove(&id)) + .await?; Ok(()) } -pub async fn load(secrets: &mut Ex, id: i32) -> Result -where - for<'a> &'a mut Ex: Executor<'a, Database = Postgres>, -{ - let record = sqlx::query!( - "SELECT hostname, path, username, password FROM cifs_shares WHERE id = $1", - id - ) - .fetch_one(secrets) - .await?; - - Ok(Cifs { - hostname: record.hostname, - path: PathBuf::from(record.path), - username: record.username, - password: record.password, - }) +pub fn load(db: &DatabaseModel, id: u32) -> Result { + db.as_private() + .as_cifs() + .as_idx(&id) + .ok_or_else(|| { + Error::new( + eyre!("Backup Target ID {} Not Found", id), + ErrorKind::NotFound, + ) + })? + .de() } -pub async fn list(secrets: &mut Ex) -> Result, Error> -where - for<'a> &'a mut Ex: Executor<'a, Database = Postgres>, -{ - let mut records = - sqlx::query!("SELECT id, hostname, path, username, password FROM cifs_shares") - .fetch_many(secrets); - +pub async fn list(db: &DatabaseModel) -> Result, Error> { let mut cifs = Vec::new(); - while let Some(query_result) = records.try_next().await? { - if let Some(record) = query_result.right() { - let mount_info = Cifs { - hostname: record.hostname, - path: PathBuf::from(record.path), - username: record.username, - password: record.password, - }; - let embassy_os = async { - let guard = TmpMountGuard::mount(&mount_info, ReadOnly).await?; - let embassy_os = recovery_info(guard.path()).await?; - guard.unmount().await?; - Ok::<_, Error>(embassy_os) - } - .await; - cifs.push(( - record.id, - CifsBackupTarget { - hostname: mount_info.hostname, - path: mount_info.path, - username: mount_info.username, - mountable: embassy_os.is_ok(), - embassy_os: embassy_os.ok().and_then(|a| a), - }, - )); + for (id, model) in db.as_private().as_cifs().as_entries()? { + let mount_info = model.de()?; + let embassy_os = async { + let guard = TmpMountGuard::mount(&mount_info, ReadOnly).await?; + let embassy_os = recovery_info(guard.path()).await?; + guard.unmount().await?; + Ok::<_, Error>(embassy_os) } + .await; + cifs.push(( + id, + CifsBackupTarget { + hostname: mount_info.hostname, + path: mount_info.path, + username: mount_info.username, + mountable: embassy_os.is_ok(), + embassy_os: embassy_os.ok().and_then(|a| a), + }, + )); } Ok(cifs) diff --git a/core/startos/src/backup/target/mod.rs b/core/startos/src/backup/target/mod.rs index 473b2865d..72dd45832 100644 --- a/core/startos/src/backup/target/mod.rs +++ b/core/startos/src/backup/target/mod.rs @@ -11,12 +11,12 @@ use models::PackageId; use rpc_toolkit::{command, from_fn_async, AnyContext, HandlerExt, ParentHandler}; use serde::{Deserialize, Serialize}; use sha2::Sha256; -use sqlx::{Executor, Postgres}; use tokio::sync::Mutex; use tracing::instrument; use self::cifs::CifsBackupTarget; use crate::context::{CliContext, RpcContext}; +use crate::db::model::DatabaseModel; use crate::disk::mount::backup::BackupMountGuard; use crate::disk::mount::filesystem::block_dev::BlockDev; use crate::disk::mount::filesystem::cifs::Cifs; @@ -49,18 +49,15 @@ pub enum BackupTarget { #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] pub enum BackupTargetId { Disk { logicalname: PathBuf }, - Cifs { id: i32 }, + Cifs { id: u32 }, } impl BackupTargetId { - pub async fn load(self, secrets: &mut Ex) -> Result - where - for<'a> &'a mut Ex: Executor<'a, Database = Postgres>, - { + pub fn load(self, db: &DatabaseModel) -> Result { Ok(match self { BackupTargetId::Disk { logicalname } => { BackupTargetFS::Disk(BlockDev::new(logicalname)) } - BackupTargetId::Cifs { id } => BackupTargetFS::Cifs(cifs::load(secrets, id).await?), + BackupTargetId::Cifs { id } => BackupTargetFS::Cifs(cifs::load(db, id)?), }) } } @@ -161,10 +158,10 @@ pub fn target() -> ParentHandler { // #[command(display(display_serializable))] pub async fn list(ctx: RpcContext) -> Result, Error> { - let mut sql_handle = ctx.secret_store.acquire().await?; + let peek = ctx.db.peek().await; let (disks_res, cifs) = tokio::try_join!( crate::disk::util::list(&ctx.os_partitions), - cifs::list(sql_handle.as_mut()), + cifs::list(&peek), )?; Ok(disks_res .into_iter() @@ -262,13 +259,7 @@ pub async fn info( }: InfoParams, ) -> Result { let guard = BackupMountGuard::mount( - TmpMountGuard::mount( - &target_id - .load(ctx.secret_store.acquire().await?.as_mut()) - .await?, - ReadWrite, - ) - .await?, + TmpMountGuard::mount(&target_id.load(&ctx.db.peek().await)?, ReadWrite).await?, &password, ) .await?; @@ -308,14 +299,7 @@ pub async fn mount( } let guard = BackupMountGuard::mount( - TmpMountGuard::mount( - &target_id - .clone() - .load(ctx.secret_store.acquire().await?.as_mut()) - .await?, - ReadWrite, - ) - .await?, + TmpMountGuard::mount(&target_id.clone().load(&ctx.db.peek().await)?, ReadWrite).await?, &password, ) .await?; diff --git a/core/startos/src/context/config.rs b/core/startos/src/context/config.rs index fc9cfb790..55065e816 100644 --- a/core/startos/src/context/config.rs +++ b/core/startos/src/context/config.rs @@ -3,15 +3,12 @@ use std::net::SocketAddr; use std::path::{Path, PathBuf}; use clap::Parser; -use patch_db::json_ptr::JsonPointer; use reqwest::Url; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use sqlx::postgres::PgConnectOptions; use sqlx::PgPool; -use crate::account::AccountInfo; -use crate::db::model::Database; use crate::disk::OsPartitionInfo; use crate::init::init_postgres; use crate::prelude::*; @@ -149,15 +146,12 @@ impl ServerConfig { .as_deref() .unwrap_or_else(|| Path::new("/embassy-data")) } - pub async fn db(&self, account: &AccountInfo) -> Result { + pub async fn db(&self) -> Result { let db_path = self.datadir().join("main").join("embassy.db"); let db = PatchDb::open(&db_path) .await .with_ctx(|_| (crate::ErrorKind::Filesystem, db_path.display().to_string()))?; - if !db.exists(&::default()).await { - db.put(&::default(), &Database::init(account)) - .await?; - } + Ok(db) } #[instrument(skip_all)] diff --git a/core/startos/src/context/rpc.rs b/core/startos/src/context/rpc.rs index 905132071..0f11bf63f 100644 --- a/core/startos/src/context/rpc.rs +++ b/core/startos/src/context/rpc.rs @@ -11,7 +11,6 @@ use josekit::jwk::Jwk; use patch_db::PatchDb; use reqwest::{Client, Proxy}; use rpc_toolkit::Context; -use sqlx::PgPool; use tokio::sync::{broadcast, oneshot, Mutex, RwLock}; use tokio::time::Instant; use tracing::instrument; @@ -28,14 +27,11 @@ use crate::init::check_time_is_synchronized; use crate::lxc::{LxcContainer, LxcManager}; use crate::middleware::auth::HashSessionToken; use crate::net::net_controller::NetController; -use crate::net::ssl::{root_ca_start_time, SslManager}; use crate::net::utils::find_eth_iface; use crate::net::wifi::WpaCli; -use crate::notifications::NotificationManager; use crate::prelude::*; use crate::service::ServiceMap; use crate::shutdown::Shutdown; -use crate::status::MainStatus; use crate::system::get_mem_info; use crate::util::lshw::{lshw, LshwDevice}; @@ -47,14 +43,12 @@ pub struct RpcContextSeed { pub datadir: PathBuf, pub disk_guid: Arc, pub db: PatchDb, - pub secret_store: PgPool, pub account: RwLock, pub net_controller: Arc, pub services: ServiceMap, pub metrics_cache: RwLock>, pub shutdown: broadcast::Sender>, pub tor_socks: SocketAddr, - pub notification_manager: NotificationManager, pub lxc_manager: Arc, pub open_authed_websockets: Mutex>>>, pub rpc_stream_continuations: Mutex>, @@ -86,13 +80,14 @@ impl RpcContext { 9050, ))); let (shutdown, _) = tokio::sync::broadcast::channel(1); - let secret_store = config.secret_store().await?; - tracing::info!("Opened Pg DB"); - let account = AccountInfo::load(&secret_store).await?; - let db = config.db(&account).await?; + + let db = config.db().await?; + let peek = db.peek().await; + let account = AccountInfo::load(&peek)?; tracing::info!("Opened PatchDB"); let net_controller = Arc::new( NetController::init( + db.clone(), config .tor_control .unwrap_or(SocketAddr::from(([127, 0, 0, 1], 9051))), @@ -101,16 +96,14 @@ impl RpcContext { .dns_bind .as_deref() .unwrap_or(&[SocketAddr::from(([127, 0, 0, 1], 53))]), - SslManager::new(&account, root_ca_start_time().await?)?, &account.hostname, - &account.key, + account.tor_key.clone(), ) .await?, ); tracing::info!("Initialized Net Controller"); let services = ServiceMap::default(); let metrics_cache = RwLock::>::new(None); - let notification_manager = NotificationManager::new(secret_store.clone()); tracing::info!("Initialized Notification Manager"); let tor_proxy_url = format!("socks5h://{tor_proxy}"); let devices = lshw().await?; @@ -157,14 +150,12 @@ impl RpcContext { }, disk_guid, db, - secret_store, account: RwLock::new(account), net_controller, services, metrics_cache, shutdown, tor_socks: tor_proxy, - notification_manager, lxc_manager: Arc::new(LxcManager::new()), open_authed_websockets: Mutex::new(BTreeMap::new()), rpc_stream_continuations: Mutex::new(BTreeMap::new()), @@ -208,7 +199,6 @@ impl RpcContext { #[instrument(skip_all)] pub async fn shutdown(self) -> Result<(), Error> { self.services.shutdown_all().await?; - self.secret_store.close().await; self.is_closed.store(true, Ordering::SeqCst); tracing::info!("RPC Context is shutdown"); // TODO: shutdown http servers diff --git a/core/startos/src/context/setup.rs b/core/startos/src/context/setup.rs index aeeca2920..933aa155c 100644 --- a/core/startos/src/context/setup.rs +++ b/core/startos/src/context/setup.rs @@ -3,7 +3,6 @@ use std::path::PathBuf; use std::sync::Arc; use josekit::jwk::Jwk; -use patch_db::json_ptr::JsonPointer; use patch_db::PatchDb; use rpc_toolkit::yajrc::RpcError; use rpc_toolkit::Context; @@ -14,9 +13,7 @@ use tokio::sync::broadcast::Sender; use tokio::sync::RwLock; use tracing::instrument; -use crate::account::AccountInfo; use crate::context::config::ServerConfig; -use crate::db::model::Database; use crate::disk::OsPartitionInfo; use crate::init::init_postgres; use crate::prelude::*; @@ -81,15 +78,11 @@ impl SetupContext { }))) } #[instrument(skip_all)] - pub async fn db(&self, account: &AccountInfo) -> Result { + pub async fn db(&self) -> Result { let db_path = self.datadir.join("main").join("embassy.db"); let db = PatchDb::open(&db_path) .await .with_ctx(|_| (crate::ErrorKind::Filesystem, db_path.display().to_string()))?; - if !db.exists(&::default()).await { - db.put(&::default(), &Database::init(account)) - .await?; - } Ok(db) } #[instrument(skip_all)] diff --git a/core/startos/src/control.rs b/core/startos/src/control.rs index 893aeee2b..26b227cd2 100644 --- a/core/startos/src/control.rs +++ b/core/startos/src/control.rs @@ -24,7 +24,7 @@ pub async fn start(ctx: RpcContext, ControlParams { id }: ControlParams) -> Resu .as_ref() .or_not_found(lazy_format!("Manager for {id}"))? .start() - .await; + .await?; Ok(()) } @@ -37,7 +37,7 @@ pub async fn stop(ctx: RpcContext, ControlParams { id }: ControlParams) -> Resul .as_ref() .ok_or_else(|| Error::new(eyre!("Manager not found"), crate::ErrorKind::InvalidRequest))? .stop() - .await; + .await?; Ok(()) } @@ -49,7 +49,7 @@ pub async fn restart(ctx: RpcContext, ControlParams { id }: ControlParams) -> Re .as_ref() .ok_or_else(|| Error::new(eyre!("Manager not found"), crate::ErrorKind::InvalidRequest))? .restart() - .await; + .await?; Ok(()) } diff --git a/core/startos/src/db/model.rs b/core/startos/src/db/model.rs index 571573a54..19cfa37a1 100644 --- a/core/startos/src/db/model.rs +++ b/core/startos/src/db/model.rs @@ -13,30 +13,46 @@ use patch_db::json_ptr::JsonPointer; use patch_db::{HasModel, Value}; use reqwest::Url; use serde::{Deserialize, Serialize}; -use ssh_key::public::Ed25519PublicKey; +use torut::onion::OnionAddressV3; use crate::account::AccountInfo; +use crate::auth::Sessions; +use crate::backup::target::cifs::CifsTargets; +use crate::net::forward::AvailablePorts; +use crate::net::host::HostInfo; +use crate::net::keys::KeyStore; use crate::net::utils::{get_iface_ipv4_addr, get_iface_ipv6_addr}; +use crate::notifications::Notifications; use crate::prelude::*; use crate::progress::FullProgress; use crate::s9pk::manifest::Manifest; +use crate::ssh::SshKeys; use crate::status::Status; use crate::util::cpupower::Governor; +use crate::util::serde::Pem; use crate::util::Version; use crate::version::{Current, VersionT}; use crate::{ARCH, PLATFORM}; +fn get_arch() -> InternedString { + (*ARCH).into() +} + +fn get_platform() -> InternedString { + (&*PLATFORM).into() +} + #[derive(Debug, Deserialize, Serialize, HasModel)] #[serde(rename_all = "kebab-case")] #[model = "Model"] pub struct Database { pub public: Public, - pub private: (), // TODO + pub private: Private, } impl Database { - pub fn init(account: &AccountInfo) -> Self { + pub fn init(account: &AccountInfo) -> Result { let lan_address = account.hostname.lan_address().parse().unwrap(); - Database { + Ok(Database { public: Public { server_info: ServerInfo { arch: get_arch(), @@ -48,9 +64,13 @@ impl Database { last_wifi_region: None, eos_version_compat: Current::new().compat().clone(), lan_address, - tor_address: format!("https://{}", account.key.tor_address()) - .parse() - .unwrap(), + onion_address: account.tor_key.public().get_onion_address(), + tor_address: format!( + "https://{}", + account.tor_key.public().get_onion_address() + ) + .parse() + .unwrap(), ip_info: BTreeMap::new(), status_info: ServerStatus { backup_progress: None, @@ -70,11 +90,9 @@ impl Database { clearnet: Vec::new(), }, password_hash: account.password.clone(), - pubkey: ssh_key::PublicKey::from(Ed25519PublicKey::from( - &account.key.ssh_key(), - )) - .to_openssh() - .unwrap(), + pubkey: ssh_key::PublicKey::from(&account.ssh_key) + .to_openssh() + .unwrap(), ca_fingerprint: account .root_ca_cert .digest(MessageDigest::sha256()) @@ -93,11 +111,22 @@ impl Database { ))) .unwrap(), }, - private: (), // TODO - } + private: Private { + key_store: KeyStore::new(account)?, + password: account.password.clone(), + ssh_privkey: Pem(account.ssh_key.clone()), + ssh_pubkeys: SshKeys::new(), + available_ports: AvailablePorts::new(), + sessions: Sessions::new(), + notifications: Notifications::new(), + cifs: CifsTargets::new(), + }, // TODO + }) } } +pub type DatabaseModel = Model; + #[derive(Debug, Deserialize, Serialize, HasModel)] #[serde(rename_all = "kebab-case")] #[model = "Model"] @@ -108,14 +137,18 @@ pub struct Public { pub ui: Value, } -pub type DatabaseModel = Model; - -fn get_arch() -> InternedString { - (*ARCH).into() -} - -fn get_platform() -> InternedString { - (&*PLATFORM).into() +#[derive(Debug, Deserialize, Serialize, HasModel)] +#[serde(rename_all = "kebab-case")] +#[model = "Model"] +pub struct Private { + pub key_store: KeyStore, + pub password: String, // argon2 hash + pub ssh_privkey: Pem, + pub ssh_pubkeys: SshKeys, + pub available_ports: AvailablePorts, + pub sessions: Sessions, + pub notifications: Notifications, + pub cifs: CifsTargets, } #[derive(Debug, Deserialize, Serialize, HasModel)] @@ -134,6 +167,8 @@ pub struct ServerInfo { pub last_wifi_region: Option, pub eos_version_compat: VersionRange, pub lan_address: Url, + pub onion_address: OnionAddressV3, + /// for backwards compatibility pub tor_address: Url, pub ip_info: BTreeMap, #[serde(default)] @@ -229,6 +264,12 @@ pub struct AllPackageData(pub BTreeMap); impl Map for AllPackageData { type Key = PackageId; type Value = PackageDataEntry; + fn key_str(key: &Self::Key) -> Result, Error> { + Ok(key) + } + fn key_string(key: &Self::Key) -> Result { + Ok(key.clone().into()) + } } #[derive(Debug, Deserialize, Serialize, HasModel)] @@ -471,6 +512,7 @@ pub struct InstalledPackageInfo { pub current_dependents: CurrentDependents, pub current_dependencies: CurrentDependencies, pub interface_addresses: InterfaceAddressMap, + pub hosts: HostInfo, pub store: Value, pub store_exposed_ui: Vec, pub store_exposed_dependents: Vec, @@ -512,6 +554,12 @@ impl CurrentDependents { impl Map for CurrentDependents { type Key = PackageId; type Value = CurrentDependencyInfo; + fn key_str(key: &Self::Key) -> Result, Error> { + Ok(key) + } + fn key_string(key: &Self::Key) -> Result { + Ok(key.clone().into()) + } } #[derive(Debug, Clone, Default, Deserialize, Serialize)] pub struct CurrentDependencies(pub BTreeMap); @@ -529,6 +577,12 @@ impl CurrentDependencies { impl Map for CurrentDependencies { type Key = PackageId; type Value = CurrentDependencyInfo; + fn key_str(key: &Self::Key) -> Result, Error> { + Ok(key) + } + fn key_string(key: &Self::Key) -> Result { + Ok(key.clone().into()) + } } #[derive(Debug, Deserialize, Serialize, HasModel)] @@ -552,6 +606,12 @@ pub struct InterfaceAddressMap(pub BTreeMap); impl Map for InterfaceAddressMap { type Key = HostId; type Value = InterfaceAddresses; + fn key_str(key: &Self::Key) -> Result, Error> { + Ok(key) + } + fn key_string(key: &Self::Key) -> Result { + Ok(key.clone().into()) + } } #[derive(Debug, Deserialize, Serialize, HasModel)] diff --git a/core/startos/src/db/prelude.rs b/core/startos/src/db/prelude.rs index 14f5e21eb..911b2d389 100644 --- a/core/startos/src/db/prelude.rs +++ b/core/startos/src/db/prelude.rs @@ -1,13 +1,15 @@ use std::collections::BTreeMap; use std::marker::PhantomData; use std::panic::UnwindSafe; +use std::str::FromStr; +use chrono::{DateTime, Utc}; pub use imbl_value::Value; use patch_db::json_ptr::ROOT; use patch_db::value::InternedString; pub use patch_db::{HasModel, PatchDb}; use serde::de::DeserializeOwned; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use crate::db::model::DatabaseModel; use crate::prelude::*; @@ -92,12 +94,37 @@ impl Model { } } +impl Serialize for Model { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.value.serialize(serializer) + } +} + +impl<'de, T: Serialize + Deserialize<'de>> Deserialize<'de> for Model { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + Self::new(&T::deserialize(deserializer)?).map_err(D::Error::custom) + } +} + impl Model { pub fn replace(&mut self, value: &T) -> Result { let orig = self.de()?; self.ser(value)?; Ok(orig) } + pub fn mutate(&mut self, f: impl FnOnce(&mut T) -> Result) -> Result { + let mut orig = self.de()?; + let res = f(&mut orig)?; + self.ser(&orig)?; + Ok(res) + } } impl Clone for Model { fn clone(&self) -> Self { @@ -181,20 +208,38 @@ impl Model> { pub trait Map: DeserializeOwned + Serialize { type Key; type Value; + fn key_str(key: &Self::Key) -> Result, Error>; + fn key_string(key: &Self::Key) -> Result { + Ok(InternedString::intern(Self::key_str(key)?.as_ref())) + } } impl Map for BTreeMap +where + A: serde::Serialize + serde::de::DeserializeOwned + Ord + AsRef, + B: serde::Serialize + serde::de::DeserializeOwned, +{ + type Key = A; + type Value = B; + fn key_str(key: &Self::Key) -> Result, Error> { + Ok(key.as_ref()) + } +} + +impl Map for BTreeMap, B> where A: serde::Serialize + serde::de::DeserializeOwned + Ord, B: serde::Serialize + serde::de::DeserializeOwned, { type Key = A; type Value = B; + fn key_str(key: &Self::Key) -> Result, Error> { + serde_json::to_string(key).with_kind(ErrorKind::Serialization) + } } impl Model where - T::Key: AsRef, T::Value: Serialize, { pub fn insert(&mut self, key: &T::Key, value: &T::Value) -> Result<(), Error> { @@ -202,7 +247,7 @@ where let v = patch_db::value::to_value(value)?; match &mut self.value { Value::Object(o) => { - o.insert(InternedString::intern(key.as_ref()), v); + o.insert(T::key_string(key)?, v); Ok(()) } v => Err(patch_db::value::Error { @@ -212,13 +257,40 @@ where .into()), } } + pub fn upsert(&mut self, key: &T::Key, value: F) -> Result<&mut Model, Error> + where + F: FnOnce() -> D, + D: AsRef, + { + use serde::ser::Error; + match &mut self.value { + Value::Object(o) => { + use patch_db::ModelExt; + let s = T::key_str(key)?; + let exists = o.contains_key(s.as_ref()); + let res = self.transmute_mut(|v| { + use patch_db::value::index::Index; + s.as_ref().index_or_insert(v) + }); + if !exists { + res.ser(value().as_ref())?; + } + Ok(res) + } + v => Err(patch_db::value::Error { + source: patch_db::value::ErrorSource::custom(format!("expected object found {v}")), + kind: patch_db::value::ErrorKind::Serialization, + } + .into()), + } + } pub fn insert_model(&mut self, key: &T::Key, value: Model) -> Result<(), Error> { use patch_db::ModelExt; use serde::ser::Error; let v = value.into_value(); match &mut self.value { Value::Object(o) => { - o.insert(InternedString::intern(key.as_ref()), v); + o.insert(T::key_string(key)?, v); Ok(()) } v => Err(patch_db::value::Error { @@ -232,25 +304,16 @@ where impl Model where - T::Key: DeserializeOwned + Ord + Clone, + T::Key: FromStr + Ord + Clone, + Error: From<::Err>, { pub fn keys(&self) -> Result, Error> { use serde::de::Error; - use serde::Deserialize; match &self.value { Value::Object(o) => o .keys() .cloned() - .map(|k| { - T::Key::deserialize(patch_db::value::de::InternedStringDeserializer::from(k)) - .map_err(|e| { - patch_db::value::Error { - kind: patch_db::value::ErrorKind::Deserialization, - source: e, - } - .into() - }) - }) + .map(|k| Ok(T::Key::from_str(&*k)?)) .collect(), v => Err(patch_db::value::Error { source: patch_db::value::ErrorSource::custom(format!("expected object found {v}")), @@ -263,19 +326,10 @@ where pub fn into_entries(self) -> Result)>, Error> { use patch_db::ModelExt; use serde::de::Error; - use serde::Deserialize; match self.value { Value::Object(o) => o .into_iter() - .map(|(k, v)| { - Ok(( - T::Key::deserialize(patch_db::value::de::InternedStringDeserializer::from( - k, - )) - .with_kind(ErrorKind::Deserialization)?, - Model::from_value(v), - )) - }) + .map(|(k, v)| Ok((T::Key::from_str(&*k)?, Model::from_value(v)))) .collect(), v => Err(patch_db::value::Error { source: patch_db::value::ErrorSource::custom(format!("expected object found {v}")), @@ -287,19 +341,10 @@ where pub fn as_entries(&self) -> Result)>, Error> { use patch_db::ModelExt; use serde::de::Error; - use serde::Deserialize; match &self.value { Value::Object(o) => o .iter() - .map(|(k, v)| { - Ok(( - T::Key::deserialize(patch_db::value::de::InternedStringDeserializer::from( - k.clone(), - )) - .with_kind(ErrorKind::Deserialization)?, - Model::value_as(v), - )) - }) + .map(|(k, v)| Ok((T::Key::from_str(&**k)?, Model::value_as(v)))) .collect(), v => Err(patch_db::value::Error { source: patch_db::value::ErrorSource::custom(format!("expected object found {v}")), @@ -311,19 +356,10 @@ where pub fn as_entries_mut(&mut self) -> Result)>, Error> { use patch_db::ModelExt; use serde::de::Error; - use serde::Deserialize; match &mut self.value { Value::Object(o) => o .iter_mut() - .map(|(k, v)| { - Ok(( - T::Key::deserialize(patch_db::value::de::InternedStringDeserializer::from( - k.clone(), - )) - .with_kind(ErrorKind::Deserialization)?, - Model::value_as_mut(v), - )) - }) + .map(|(k, v)| Ok((T::Key::from_str(&**k)?, Model::value_as_mut(v)))) .collect(), v => Err(patch_db::value::Error { source: patch_db::value::ErrorSource::custom(format!("expected object found {v}")), @@ -333,36 +369,36 @@ where } } } -impl Model -where - T::Key: AsRef, -{ +impl Model { pub fn into_idx(self, key: &T::Key) -> Option> { use patch_db::ModelExt; + let s = T::key_str(key).ok()?; match &self.value { - Value::Object(o) if o.contains_key(key.as_ref()) => Some(self.transmute(|v| { + Value::Object(o) if o.contains_key(s.as_ref()) => Some(self.transmute(|v| { use patch_db::value::index::Index; - key.as_ref().index_into_owned(v).unwrap() + s.as_ref().index_into_owned(v).unwrap() })), _ => None, } } pub fn as_idx<'a>(&'a self, key: &T::Key) -> Option<&'a Model> { use patch_db::ModelExt; + let s = T::key_str(key).ok()?; match &self.value { - Value::Object(o) if o.contains_key(key.as_ref()) => Some(self.transmute_ref(|v| { + Value::Object(o) if o.contains_key(s.as_ref()) => Some(self.transmute_ref(|v| { use patch_db::value::index::Index; - key.as_ref().index_into(v).unwrap() + s.as_ref().index_into(v).unwrap() })), _ => None, } } pub fn as_idx_mut<'a>(&'a mut self, key: &T::Key) -> Option<&'a mut Model> { use patch_db::ModelExt; + let s = T::key_str(key).ok()?; match &mut self.value { - Value::Object(o) if o.contains_key(key.as_ref()) => Some(self.transmute_mut(|v| { + Value::Object(o) if o.contains_key(s.as_ref()) => Some(self.transmute_mut(|v| { use patch_db::value::index::Index; - key.as_ref().index_or_insert(v) + s.as_ref().index_or_insert(v) })), _ => None, } @@ -371,7 +407,7 @@ where use serde::ser::Error; match &mut self.value { Value::Object(o) => { - let v = o.remove(key.as_ref()); + let v = o.remove(T::key_str(key)?.as_ref()); Ok(v.map(patch_db::ModelExt::from_value)) } v => Err(patch_db::value::Error { @@ -382,3 +418,90 @@ where } } } + +#[repr(transparent)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct JsonKey(pub T); +impl From for JsonKey { + fn from(value: T) -> Self { + Self::new(value) + } +} +impl JsonKey { + pub fn new(value: T) -> Self { + Self(value) + } + pub fn unwrap(self) -> T { + self.0 + } + pub fn new_ref(value: &T) -> &Self { + unsafe { std::mem::transmute(value) } + } + pub fn new_mut(value: &mut T) -> &mut Self { + unsafe { std::mem::transmute(value) } + } +} +impl std::ops::Deref for JsonKey { + type Target = T; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl std::ops::DerefMut for JsonKey { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} +impl Serialize for JsonKey { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::Error; + serde_json::to_string(&self.0) + .map_err(S::Error::custom)? + .serialize(serializer) + } +} +// { "foo": "bar" } -> "{ \"foo\": \"bar\" }" +impl<'de, T: Serialize + DeserializeOwned> Deserialize<'de> for JsonKey { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::Error; + let string = String::deserialize(deserializer)?; + Ok(Self( + serde_json::from_str(&string).map_err(D::Error::custom)?, + )) + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct WithTimeData { + pub created_at: DateTime, + pub updated_at: DateTime, + pub value: T, +} +impl WithTimeData { + pub fn new(value: T) -> Self { + let now = Utc::now(); + Self { + created_at: now, + updated_at: now, + value, + } + } +} +impl std::ops::Deref for WithTimeData { + type Target = T; + fn deref(&self) -> &Self::Target { + &self.value + } +} +impl std::ops::DerefMut for WithTimeData { + fn deref_mut(&mut self) -> &mut Self::Target { + self.updated_at = Utc::now(); + &mut self.value + } +} diff --git a/core/startos/src/dependencies.rs b/core/startos/src/dependencies.rs index 6ebe7afed..4e96a3db3 100644 --- a/core/startos/src/dependencies.rs +++ b/core/startos/src/dependencies.rs @@ -9,13 +9,11 @@ use serde::{Deserialize, Serialize}; use tracing::instrument; use crate::config::{Config, ConfigSpec, ConfigureContext}; -use crate::context::{CliContext, RpcContext}; +use crate::context::RpcContext; use crate::db::model::{CurrentDependencies, Database}; use crate::prelude::*; use crate::s9pk::manifest::Manifest; use crate::status::DependencyConfigErrors; -use crate::util::serde::HandlerExtSerde; -use crate::util::Version; use crate::Error; pub fn dependency() -> ParentHandler { @@ -28,6 +26,12 @@ pub struct Dependencies(pub BTreeMap); impl Map for Dependencies { type Key = PackageId; type Value = DepInfo; + fn key_str(key: &Self::Key) -> Result, Error> { + Ok(key) + } + fn key_string(key: &Self::Key) -> Result { + Ok(key.clone().into()) + } } #[derive(Clone, Debug, Deserialize, Serialize)] diff --git a/core/startos/src/firmware.rs b/core/startos/src/firmware.rs index ed4a6577a..86f576cce 100644 --- a/core/startos/src/firmware.rs +++ b/core/startos/src/firmware.rs @@ -2,7 +2,6 @@ use std::collections::BTreeSet; use std::path::Path; use async_compression::tokio::bufread::GzipDecoder; -use clap::Parser; use serde::{Deserialize, Serialize}; use tokio::fs::File; use tokio::io::BufReader; diff --git a/core/startos/src/init.rs b/core/startos/src/init.rs index fab80ab09..68a57fca9 100644 --- a/core/startos/src/init.rs +++ b/core/startos/src/init.rs @@ -6,7 +6,6 @@ use std::time::{Duration, SystemTime}; use color_eyre::eyre::eyre; use models::ResultExt; use rand::random; -use sqlx::{Pool, Postgres}; use tokio::process::Command; use tracing::instrument; @@ -179,7 +178,6 @@ pub async fn init_postgres(datadir: impl AsRef) -> Result<(), Error> { } pub struct InitResult { - pub secret_store: Pool, pub db: patch_db::PatchDb, } @@ -208,16 +206,19 @@ pub async fn init(cfg: &ServerConfig) -> Result { .await?; } - let secret_store = cfg.secret_store().await?; - tracing::info!("Opened Postgres"); + let db = cfg.db().await?; + let peek = db.peek().await; + tracing::info!("Opened PatchDB"); - crate::ssh::sync_keys_from_db(&secret_store, "/home/start9/.ssh/authorized_keys").await?; + crate::ssh::sync_keys( + &peek.as_private().as_ssh_pubkeys().de()?, + "/home/start9/.ssh/authorized_keys", + ) + .await?; tracing::info!("Synced SSH Keys"); - let account = AccountInfo::load(&secret_store).await?; - let db = cfg.db(&account).await?; - tracing::info!("Opened PatchDB"); - let peek = db.peek().await; + let account = AccountInfo::load(&peek)?; + let mut server_info = peek.as_public().as_server_info().de()?; // write to ca cert store @@ -348,7 +349,7 @@ pub async fn init(cfg: &ServerConfig) -> Result { }) .await?; - crate::version::init(&db, &secret_store).await?; + crate::version::init(&db).await?; db.mutate(|d| { let model = d.de()?; @@ -366,5 +367,5 @@ pub async fn init(cfg: &ServerConfig) -> Result { tracing::info!("System initialized."); - Ok(InitResult { secret_store, db }) + Ok(InitResult { db }) } diff --git a/core/startos/src/install/mod.rs b/core/startos/src/install/mod.rs index ac00a750b..6ea4a7129 100644 --- a/core/startos/src/install/mod.rs +++ b/core/startos/src/install/mod.rs @@ -1,4 +1,3 @@ -use std::io::SeekFrom; use std::path::PathBuf; use std::time::Duration; @@ -14,7 +13,6 @@ use rpc_toolkit::yajrc::RpcError; use rpc_toolkit::CallRemote; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; -use tokio::io::{AsyncReadExt, AsyncSeekExt}; use tokio::sync::oneshot; use tracing::instrument; @@ -28,8 +26,6 @@ use crate::prelude::*; use crate::progress::{FullProgress, PhasedProgressBar}; use crate::s9pk::manifest::PackageId; use crate::s9pk::merkle_archive::source::http::HttpSource; -use crate::s9pk::v1::reader::S9pkReader; -use crate::s9pk::v2::compat::{self, MAGIC_AND_VERSION}; use crate::s9pk::S9pk; use crate::upload::upload; use crate::util::clap::FromStrParser; diff --git a/core/startos/src/logs.rs b/core/startos/src/logs.rs index 0b7ef3c67..0431274ac 100644 --- a/core/startos/src/logs.rs +++ b/core/startos/src/logs.rs @@ -7,7 +7,7 @@ use chrono::{DateTime, Utc}; use clap::Parser; use color_eyre::eyre::eyre; use futures::stream::BoxStream; -use futures::{FutureExt, SinkExt, Stream, StreamExt, TryStreamExt}; +use futures::{FutureExt, Stream, StreamExt, TryStreamExt}; use models::PackageId; use rpc_toolkit::yajrc::RpcError; use rpc_toolkit::{command, from_fn_async, CallRemote, Empty, HandlerExt, ParentHandler}; diff --git a/core/startos/src/lxc/mod.rs b/core/startos/src/lxc/mod.rs index 136a8423b..374952f98 100644 --- a/core/startos/src/lxc/mod.rs +++ b/core/startos/src/lxc/mod.rs @@ -1,4 +1,5 @@ use std::collections::BTreeSet; +use std::net::Ipv4Addr; use std::ops::Deref; use std::path::Path; use std::sync::{Arc, Weak}; @@ -109,6 +110,7 @@ impl LxcManager { pub struct LxcContainer { manager: Weak, rootfs: OverlayGuard, + ip: Ipv4Addr, guid: Arc, rpc_bind: TmpMountGuard, config: LxcConfig, @@ -169,9 +171,20 @@ impl LxcContainer { .arg(&*guid) .invoke(ErrorKind::Lxc) .await?; + let ip = String::from_utf8( + Command::new("lxc-info") + .arg("--name") + .arg(&*guid) + .arg("-iH") + .invoke(ErrorKind::Docker) + .await?, + )? + .trim() + .parse()?; Ok(Self { manager: Arc::downgrade(manager), rootfs, + ip, guid: Arc::new(guid), rpc_bind, config, @@ -183,6 +196,10 @@ impl LxcContainer { self.rootfs.path() } + pub fn ip(&self) -> Ipv4Addr { + self.ip + } + pub fn rpc_dir(&self) -> &Path { self.rpc_bind.path() } diff --git a/core/startos/src/middleware/auth.rs b/core/startos/src/middleware/auth.rs index 9260d7fa2..2eddcd3ad 100644 --- a/core/startos/src/middleware/auth.rs +++ b/core/startos/src/middleware/auth.rs @@ -1,14 +1,17 @@ use std::borrow::Borrow; +use std::collections::BTreeSet; +use std::ops::Deref; use std::sync::Arc; use std::time::{Duration, Instant}; use axum::extract::Request; use axum::response::Response; use basic_cookies::Cookie; +use chrono::Utc; use color_eyre::eyre::eyre; use digest::Digest; use helpers::const_true; -use http::header::COOKIE; +use http::header::{COOKIE, USER_AGENT}; use http::HeaderValue; use imbl_value::InternedString; use rpc_toolkit::yajrc::INTERNAL_ERROR; @@ -38,24 +41,36 @@ pub struct HasLoggedOutSessions(()); impl HasLoggedOutSessions { pub async fn new( - logged_out_sessions: impl IntoIterator, + sessions: impl IntoIterator, ctx: &RpcContext, ) -> Result { - let mut open_authed_websockets = ctx.open_authed_websockets.lock().await; - let mut sqlx_conn = ctx.secret_store.acquire().await?; - for session in logged_out_sessions { - let session = session.as_logout_session_id(); - let session = &*session; - sqlx::query!( - "UPDATE session SET logged_out = CURRENT_TIMESTAMP WHERE id = $1", - session - ) - .execute(sqlx_conn.as_mut()) + let to_log_out: BTreeSet<_> = sessions + .into_iter() + .map(|s| s.as_logout_session_id()) + .collect(); + ctx.open_authed_websockets + .lock() + .await + .retain(|session, sockets| { + if to_log_out.contains(session.hashed()) { + for socket in std::mem::take(sockets) { + let _ = socket.send(()); + } + false + } else { + true + } + }); + ctx.db + .mutate(|db| { + let sessions = db.as_private_mut().as_sessions_mut(); + for sid in &to_log_out { + sessions.remove(sid)?; + } + + Ok(()) + }) .await?; - for socket in open_authed_websockets.remove(session).unwrap_or_default() { - let _ = socket.send(()); - } - } Ok(HasLoggedOutSessions(())) } } @@ -105,15 +120,20 @@ impl HasValidSession { ctx: &RpcContext, ) -> Result { let session_hash = session_token.hashed(); - let session = sqlx::query!("UPDATE session SET last_active = CURRENT_TIMESTAMP WHERE id = $1 AND logged_out IS NULL OR logged_out > CURRENT_TIMESTAMP", session_hash) - .execute(ctx.secret_store.acquire().await?.as_mut()) + ctx.db + .mutate(|db| { + db.as_private_mut() + .as_sessions_mut() + .as_idx_mut(session_hash) + .ok_or_else(|| { + Error::new(eyre!("UNAUTHORIZED"), crate::ErrorKind::Authorization) + })? + .mutate(|s| { + s.last_active = Utc::now(); + Ok(()) + }) + }) .await?; - if session.rows_affected() == 0 { - return Err(Error::new( - eyre!("UNAUTHORIZED"), - crate::ErrorKind::Authorization, - )); - } Ok(Self(SessionType::Session(session_token))) } @@ -181,8 +201,8 @@ impl HashSessionToken { } } - pub fn hashed(&self) -> &str { - &*self.hashed + pub fn hashed(&self) -> &InternedString { + &self.hashed } fn hash(token: &str) -> InternedString { @@ -241,6 +261,7 @@ pub struct Auth { cookie: Option, is_login: bool, set_cookie: Option, + user_agent: Option, } impl Auth { pub fn new() -> Self { @@ -249,6 +270,7 @@ impl Auth { cookie: None, is_login: false, set_cookie: None, + user_agent: None, } } } @@ -260,7 +282,8 @@ impl Middleware for Auth { _: &RpcContext, request: &mut Request, ) -> Result<(), Response> { - self.cookie = request.headers_mut().get(COOKIE).cloned(); + self.cookie = request.headers_mut().remove(COOKIE); + self.user_agent = request.headers_mut().remove(USER_AGENT); Ok(()) } async fn process_rpc_request( @@ -282,6 +305,10 @@ impl Middleware for Auth { .into()), }); } + if let Some(user_agent) = self.user_agent.as_ref().and_then(|h| h.to_str().ok()) { + request.params["user-agent"] = Value::String(Arc::new(user_agent.to_owned())) + // TODO: will this panic? + } } else if metadata.authenticated { match HasValidSession::from_header(self.cookie.as_ref(), &context).await { Err(e) => { @@ -291,7 +318,8 @@ impl Middleware for Auth { }) } Ok(HasValidSession(SessionType::Session(s))) if metadata.get_session => { - request.params["session"] = Value::String(Arc::new(s.hashed().into())); + request.params["session"] = + Value::String(Arc::new(s.hashed().deref().to_owned())); // TODO: will this panic? } _ => (), diff --git a/core/startos/src/net/dns.rs b/core/startos/src/net/dns.rs index 9eb5d3750..ba69b6c16 100644 --- a/core/startos/src/net/dns.rs +++ b/core/startos/src/net/dns.rs @@ -18,6 +18,7 @@ use trust_dns_server::proto::rr::{Name, Record, RecordType}; use trust_dns_server::server::{Request, RequestHandler, ResponseHandler, ResponseInfo}; use trust_dns_server::ServerFuture; +use crate::net::forward::START9_BRIDGE_IFACE; use crate::util::Invoke; use crate::{Error, ErrorKind, ResultExt}; @@ -163,13 +164,13 @@ impl DnsController { Command::new("resolvectl") .arg("dns") - .arg("lxcbr0") + .arg(START9_BRIDGE_IFACE) .arg("127.0.0.1") .invoke(ErrorKind::Network) .await?; Command::new("resolvectl") .arg("domain") - .arg("lxcbr0") + .arg(START9_BRIDGE_IFACE) .arg("embassy") .invoke(ErrorKind::Network) .await?; diff --git a/core/startos/src/net/forward.rs b/core/startos/src/net/forward.rs new file mode 100644 index 000000000..e954bc36a --- /dev/null +++ b/core/startos/src/net/forward.rs @@ -0,0 +1,177 @@ +use std::collections::BTreeMap; +use std::net::SocketAddr; +use std::sync::{Arc, Weak}; + +use id_pool::IdPool; +use serde::{Deserialize, Serialize}; +use tokio::process::Command; +use tokio::sync::Mutex; + +use crate::prelude::*; +use crate::util::Invoke; + +pub const START9_BRIDGE_IFACE: &str = "lxcbr0"; +pub const FIRST_DYNAMIC_PRIVATE_PORT: u16 = 49152; + +#[derive(Debug, Deserialize, Serialize)] +pub struct AvailablePorts(IdPool); +impl AvailablePorts { + pub fn new() -> Self { + Self(IdPool::new_ranged(FIRST_DYNAMIC_PRIVATE_PORT..u16::MAX)) + } + pub fn alloc(&mut self) -> Result { + self.0.request_id().ok_or_else(|| { + Error::new( + eyre!("No more dynamic ports available!"), + ErrorKind::Network, + ) + }) + } + pub fn free(&mut self, ports: impl IntoIterator) { + for port in ports { + self.0.return_id(port).unwrap_or_default(); + } + } +} + +pub struct LanPortForwardController { + forwards: Mutex>>>, +} +impl LanPortForwardController { + pub fn new() -> Self { + Self { + forwards: Mutex::new(BTreeMap::new()), + } + } + pub async fn add(&self, port: u16, addr: SocketAddr) -> Result, Error> { + let mut writable = self.forwards.lock().await; + let (prev, mut forward) = if let Some(forward) = writable.remove(&port) { + ( + forward.keys().next().cloned(), + forward + .into_iter() + .filter(|(_, rc)| rc.strong_count() > 0) + .collect(), + ) + } else { + (None, BTreeMap::new()) + }; + let rc = Arc::new(()); + forward.insert(addr, Arc::downgrade(&rc)); + let next = forward.keys().next().cloned(); + if !forward.is_empty() { + writable.insert(port, forward); + } + + update_forward(port, prev, next).await?; + Ok(rc) + } + pub async fn gc(&self, external: u16) -> Result<(), Error> { + let mut writable = self.forwards.lock().await; + let (prev, forward) = if let Some(forward) = writable.remove(&external) { + ( + forward.keys().next().cloned(), + forward + .into_iter() + .filter(|(_, rc)| rc.strong_count() > 0) + .collect(), + ) + } else { + (None, BTreeMap::new()) + }; + let next = forward.keys().next().cloned(); + if !forward.is_empty() { + writable.insert(external, forward); + } + + update_forward(external, prev, next).await + } +} + +async fn update_forward( + external: u16, + prev: Option, + next: Option, +) -> Result<(), Error> { + if prev != next { + if let Some(prev) = prev { + unforward(START9_BRIDGE_IFACE, external, prev).await?; + } + if let Some(next) = next { + forward(START9_BRIDGE_IFACE, external, next).await?; + } + } + Ok(()) +} + +// iptables -I FORWARD -o br-start9 -p tcp -d 172.18.0.2 --dport 8333 -j ACCEPT +// iptables -t nat -I PREROUTING -p tcp --dport 32768 -j DNAT --to 172.18.0.2:8333 +async fn forward(iface: &str, external: u16, addr: SocketAddr) -> Result<(), Error> { + Command::new("iptables") + .arg("-I") + .arg("FORWARD") + .arg("-o") + .arg(iface) + .arg("-p") + .arg("tcp") + .arg("-d") + .arg(addr.ip().to_string()) + .arg("--dport") + .arg(addr.port().to_string()) + .arg("-j") + .arg("ACCEPT") + .invoke(crate::ErrorKind::Network) + .await?; + Command::new("iptables") + .arg("-t") + .arg("nat") + .arg("-I") + .arg("PREROUTING") + .arg("-p") + .arg("tcp") + .arg("--dport") + .arg(external.to_string()) + .arg("-j") + .arg("DNAT") + .arg("--to") + .arg(addr.to_string()) + .invoke(crate::ErrorKind::Network) + .await?; + Ok(()) +} + +// iptables -D FORWARD -o br-start9 -p tcp -d 172.18.0.2 --dport 8333 -j ACCEPT +// iptables -t nat -D PREROUTING -p tcp --dport 32768 -j DNAT --to 172.18.0.2:8333 +async fn unforward(iface: &str, external: u16, addr: SocketAddr) -> Result<(), Error> { + Command::new("iptables") + .arg("-D") + .arg("FORWARD") + .arg("-o") + .arg(iface) + .arg("-p") + .arg("tcp") + .arg("-d") + .arg(addr.ip().to_string()) + .arg("--dport") + .arg(addr.port().to_string()) + .arg("-j") + .arg("ACCEPT") + .invoke(crate::ErrorKind::Network) + .await?; + Command::new("iptables") + .arg("-t") + .arg("nat") + .arg("-D") + .arg("PREROUTING") + .arg("-p") + .arg("tcp") + .arg("--dport") + .arg(external.to_string()) + .arg("-j") + .arg("DNAT") + .arg("--to") + .arg(addr.to_string()) + .invoke(crate::ErrorKind::Network) + .await?; + Ok(()) +} diff --git a/core/startos/src/net/host/address.rs b/core/startos/src/net/host/address.rs new file mode 100644 index 000000000..6f3ff6df2 --- /dev/null +++ b/core/startos/src/net/host/address.rs @@ -0,0 +1,9 @@ +use serde::{Deserialize, Serialize}; +use torut::onion::OnionAddressV3; + +#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord)] +#[serde(rename_all = "camelCase")] +#[serde(tag = "kind")] +pub enum HostAddress { + Onion { address: OnionAddressV3 }, +} diff --git a/core/startos/src/net/host/binding.rs b/core/startos/src/net/host/binding.rs new file mode 100644 index 000000000..0584b517b --- /dev/null +++ b/core/startos/src/net/host/binding.rs @@ -0,0 +1,71 @@ +use imbl_value::InternedString; +use serde::{Deserialize, Serialize}; + +use crate::net::forward::AvailablePorts; +use crate::net::vhost::AlpnInfo; +use crate::prelude::*; + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct BindInfo { + pub options: BindOptions, + pub assigned_lan_port: Option, +} +impl BindInfo { + pub fn new(available_ports: &mut AvailablePorts, options: BindOptions) -> Result { + let mut assigned_lan_port = None; + if options.add_ssl.is_some() || options.secure { + assigned_lan_port = Some(available_ports.alloc()?); + } + Ok(Self { + options, + assigned_lan_port, + }) + } + pub fn update( + self, + available_ports: &mut AvailablePorts, + options: BindOptions, + ) -> Result { + let Self { + mut assigned_lan_port, + .. + } = self; + if options.add_ssl.is_some() || options.secure { + assigned_lan_port = if let Some(port) = assigned_lan_port.take() { + Some(port) + } else { + Some(available_ports.alloc()?) + }; + } else { + if let Some(port) = assigned_lan_port.take() { + available_ports.free([port]); + } + } + Ok(Self { + options, + assigned_lan_port, + }) + } +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct BindOptions { + pub scheme: InternedString, + pub preferred_external_port: u16, + pub add_ssl: Option, + pub secure: bool, + pub ssl: bool, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +#[serde(rename_all = "camelCase")] +pub struct AddSslOptions { + pub scheme: InternedString, + pub preferred_external_port: u16, + // #[serde(default)] + // pub add_x_forwarded_headers: bool, // TODO + #[serde(default)] + pub alpn: AlpnInfo, +} diff --git a/core/startos/src/net/host/mod.rs b/core/startos/src/net/host/mod.rs index b2b991698..18b86ba0e 100644 --- a/core/startos/src/net/host/mod.rs +++ b/core/startos/src/net/host/mod.rs @@ -1,29 +1,84 @@ +use std::collections::{BTreeMap, BTreeSet}; + use imbl_value::InternedString; +use models::HostId; use serde::{Deserialize, Serialize}; -use crate::net::host::multi::MultiHost; +use crate::net::forward::AvailablePorts; +use crate::net::host::address::HostAddress; +use crate::net::host::binding::{BindInfo, BindOptions}; +use crate::prelude::*; -pub mod multi; +pub mod address; +pub mod binding; -pub enum Host { - Multi(MultiHost), - // Single(SingleHost), - // Static(StaticHost), +#[derive(Debug, Deserialize, Serialize, HasModel)] +#[serde(rename_all = "camelCase")] +#[model = "Model"] +pub struct Host { + pub kind: HostKind, + pub bindings: BTreeMap, + pub addresses: BTreeSet, + pub primary: Option, +} +impl AsRef for Host { + fn as_ref(&self) -> &Host { + self + } +} +impl Host { + pub fn new(kind: HostKind) -> Self { + Self { + kind, + bindings: BTreeMap::new(), + addresses: BTreeSet::new(), + primary: None, + } + } } -#[derive(Deserialize, Serialize)] -pub struct BindOptions { - scheme: InternedString, - preferred_external_port: u16, - add_ssl: Option, - secure: bool, - ssl: bool, +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub enum HostKind { + Multi, + // Single, + // Static, } -#[derive(Deserialize, Serialize)] -pub struct AddSslOptions { - scheme: InternedString, - preferred_external_port: u16, - #[serde(default)] - add_x_forwarded_headers: bool, +#[derive(Debug, Default, Deserialize, Serialize, HasModel)] +#[model = "Model"] +pub struct HostInfo(BTreeMap); + +impl Map for HostInfo { + type Key = HostId; + type Value = Host; + fn key_str(key: &Self::Key) -> Result, Error> { + Ok(key) + } + fn key_string(key: &Self::Key) -> Result { + Ok(key.clone().into()) + } +} + +impl Model { + pub fn add_binding( + &mut self, + available_ports: &mut AvailablePorts, + kind: HostKind, + id: &HostId, + internal_port: u16, + options: BindOptions, + ) -> Result<(), Error> { + self.upsert(id, || Host::new(kind))? + .as_bindings_mut() + .mutate(|b| { + let info = if let Some(info) = b.remove(&internal_port) { + info.update(available_ports, options)? + } else { + BindInfo::new(available_ports, options)? + }; + b.insert(internal_port, info); + Ok(()) + }) // TODO: handle host kind change + } } diff --git a/core/startos/src/net/host/multi.rs b/core/startos/src/net/host/multi.rs deleted file mode 100644 index 511619201..000000000 --- a/core/startos/src/net/host/multi.rs +++ /dev/null @@ -1,13 +0,0 @@ -use std::collections::BTreeMap; - -use imbl_value::InternedString; -use serde::{Deserialize, Serialize}; - -use crate::net::host::BindOptions; -use crate::net::keys::Key; - -pub struct MultiHost { - id: InternedString, - key: Key, - binds: BTreeMap, -} diff --git a/core/startos/src/net/keys.rs b/core/startos/src/net/keys.rs index 1079d4a98..02ec17329 100644 --- a/core/startos/src/net/keys.rs +++ b/core/startos/src/net/keys.rs @@ -1,393 +1,24 @@ -use clap::Parser; -use color_eyre::eyre::eyre; -use models::{HostId, Id, PackageId}; -use openssl::pkey::{PKey, Private}; -use openssl::sha::Sha256; -use openssl::x509::X509; -use p256::elliptic_curve::pkcs8::EncodePrivateKey; use serde::{Deserialize, Serialize}; -use sqlx::{Acquire, PgExecutor}; -use ssh_key::private::Ed25519PrivateKey; -use torut::onion::{OnionAddressV3, TorSecretKeyV3}; -use zeroize::Zeroize; -use crate::config::ConfigureContext; -use crate::context::RpcContext; -use crate::control::{restart, ControlParams}; -use crate::disk::fsck::RequiresReboot; -use crate::net::ssl::CertPair; +use crate::account::AccountInfo; +use crate::net::ssl::CertStore; +use crate::net::tor::OnionStore; use crate::prelude::*; -use crate::util::crypto::ed25519_expand_key; -// TODO: delete once we may change tor addresses -async fn compat( - secrets: impl PgExecutor<'_>, - host: &Option<(PackageId, HostId)>, -) -> Result, Error> { - if let Some((package, host)) = host { - if let Some(r) = sqlx::query!( - "SELECT key FROM tor WHERE package = $1 AND interface = $2", - package, - host - ) - .fetch_optional(secrets) - .await? - { - Ok(Some(<[u8; 64]>::try_from(r.key).map_err(|e| { - Error::new( - eyre!("expected vec of len 64, got len {}", e.len()), - ErrorKind::ParseDbField, - ) - })?)) - } else { - Ok(None) - } - } else if let Some(key) = sqlx::query!("SELECT tor_key FROM account WHERE id = 0") - .fetch_one(secrets) - .await? - .tor_key - { - Ok(Some(<[u8; 64]>::try_from(key).map_err(|e| { - Error::new( - eyre!("expected vec of len 64, got len {}", e.len()), - ErrorKind::ParseDbField, - ) - })?)) - } else { - Ok(None) - } +#[derive(Debug, Deserialize, Serialize, HasModel)] +#[model = "Model"] +pub struct KeyStore { + pub onion: OnionStore, + pub local_certs: CertStore, + // pub letsencrypt_certs: BTreeMap, CertData> } - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Key { - host: Option<(PackageId, HostId)>, - base: [u8; 32], - tor_key: [u8; 64], // Does NOT necessarily match base -} -impl Key { - pub fn host(&self) -> Option<(PackageId, HostId)> { - self.host.clone() - } - pub fn as_bytes(&self) -> [u8; 32] { - self.base - } - pub fn internal_address(&self) -> String { - self.host - .as_ref() - .map(|(pkg_id, _)| format!("{}.embassy", pkg_id)) - .unwrap_or_else(|| "embassy".to_owned()) - } - pub fn tor_key(&self) -> TorSecretKeyV3 { - self.tor_key.into() - } - pub fn tor_address(&self) -> OnionAddressV3 { - self.tor_key().public().get_onion_address() - } - pub fn base_address(&self) -> String { - self.tor_key() - .public() - .get_onion_address() - .get_address_without_dot_onion() - } - pub fn local_address(&self) -> String { - self.base_address() + ".local" - } - pub fn openssl_key_ed25519(&self) -> PKey { - PKey::private_key_from_raw_bytes(&self.base, openssl::pkey::Id::ED25519).unwrap() - } - pub fn openssl_key_nistp256(&self) -> PKey { - let mut buf = self.base; - loop { - if let Ok(k) = p256::SecretKey::from_slice(&buf) { - return PKey::private_key_from_pkcs8(&*k.to_pkcs8_der().unwrap().as_bytes()) - .unwrap(); - } - let mut sha = Sha256::new(); - sha.update(&buf); - buf = sha.finish(); - } - } - pub fn ssh_key(&self) -> Ed25519PrivateKey { - Ed25519PrivateKey::from_bytes(&self.base) - } - pub(crate) fn from_pair( - host: Option<(PackageId, HostId)>, - bytes: [u8; 32], - tor_key: [u8; 64], - ) -> Self { - Self { - host, - tor_key, - base: bytes, - } - } - pub fn from_bytes(host: Option<(PackageId, HostId)>, bytes: [u8; 32]) -> Self { - Self::from_pair(host, bytes, ed25519_expand_key(&bytes)) - } - pub fn new(host: Option<(PackageId, HostId)>) -> Self { - Self::from_bytes(host, rand::random()) - } - pub(super) fn with_certs(self, certs: CertPair, int: X509, root: X509) -> KeyInfo { - KeyInfo { - key: self, - certs, - int, - root, - } - } - pub async fn for_package( - secrets: impl PgExecutor<'_>, - package: &PackageId, - ) -> Result, Error> { - sqlx::query!( - r#" - SELECT - network_keys.package, - network_keys.interface, - network_keys.key, - tor.key AS "tor_key?" - FROM - network_keys - LEFT JOIN - tor - ON - network_keys.package = tor.package - AND - network_keys.interface = tor.interface - WHERE - network_keys.package = $1 - "#, - package - ) - .fetch_all(secrets) - .await? - .into_iter() - .map(|row| { - let host = Some((package.clone(), HostId::from(Id::try_from(row.interface)?))); - let bytes = row.key.try_into().map_err(|e: Vec| { - Error::new( - eyre!("Invalid length for network key {} expected 32", e.len()), - crate::ErrorKind::Database, - ) - })?; - Ok(match row.tor_key { - Some(tor_key) => Key::from_pair( - host, - bytes, - tor_key.try_into().map_err(|e: Vec| { - Error::new( - eyre!("Invalid length for tor key {} expected 64", e.len()), - crate::ErrorKind::Database, - ) - })?, - ), - None => Key::from_bytes(host, bytes), - }) - }) - .collect() - } - pub async fn for_host( - secrets: &mut Ex, - host: Option<(PackageId, HostId)>, - ) -> Result - where - for<'a> &'a mut Ex: PgExecutor<'a>, - { - let tentative = rand::random::<[u8; 32]>(); - let actual = if let Some((pkg, iface)) = &host { - let k = tentative.as_slice(); - let actual = sqlx::query!( - "INSERT INTO network_keys (package, interface, key) VALUES ($1, $2, $3) ON CONFLICT (package, interface) DO UPDATE SET package = EXCLUDED.package RETURNING key", - pkg, - iface, - k, - ) - .fetch_one(&mut *secrets) - .await?.key; - let mut bytes = tentative; - bytes.clone_from_slice(actual.get(0..32).ok_or_else(|| { - Error::new( - eyre!("Invalid key size returned from DB"), - crate::ErrorKind::Database, - ) - })?); - bytes - } else { - let actual = sqlx::query!("SELECT network_key FROM account WHERE id = 0") - .fetch_one(&mut *secrets) - .await? - .network_key; - let mut bytes = tentative; - bytes.clone_from_slice(actual.get(0..32).ok_or_else(|| { - Error::new( - eyre!("Invalid key size returned from DB"), - crate::ErrorKind::Database, - ) - })?); - bytes +impl KeyStore { + pub fn new(account: &AccountInfo) -> Result { + let mut res = Self { + onion: OnionStore::new(), + local_certs: CertStore::new(account)?, }; - let mut res = Self::from_bytes(host, actual); - if let Some(tor_key) = compat(secrets, &res.host).await? { - res.tor_key = tor_key; - } + res.onion.insert(account.tor_key.clone()); Ok(res) } } -impl Drop for Key { - fn drop(&mut self) { - self.base.zeroize(); - self.tor_key.zeroize(); - } -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub struct KeyInfo { - key: Key, - certs: CertPair, - int: X509, - root: X509, -} -impl KeyInfo { - pub fn key(&self) -> &Key { - &self.key - } - pub fn certs(&self) -> &CertPair { - &self.certs - } - pub fn int_ca(&self) -> &X509 { - &self.int - } - pub fn root_ca(&self) -> &X509 { - &self.root - } - pub fn fullchain_ed25519(&self) -> Vec<&X509> { - vec![&self.certs.ed25519, &self.int, &self.root] - } - pub fn fullchain_nistp256(&self) -> Vec<&X509> { - vec![&self.certs.nistp256, &self.int, &self.root] - } -} - -#[test] -pub fn test_keygen() { - let key = Key::new(None); - key.tor_key(); - key.openssl_key_nistp256(); -} - -pub fn display_requires_reboot(_: RotateKeysParams, args: RequiresReboot) { - if args.0 { - println!("Server must be restarted for changes to take effect"); - } -} -#[derive(Deserialize, Serialize, Parser)] -#[serde(rename_all = "kebab-case")] -#[command(rename_all = "kebab-case")] -pub struct RotateKeysParams { - package: Option, - host: Option, -} - -// #[command(display(display_requires_reboot))] -pub async fn rotate_key( - ctx: RpcContext, - RotateKeysParams { package, host }: RotateKeysParams, -) -> Result { - let mut pgcon = ctx.secret_store.acquire().await?; - let mut tx = pgcon.begin().await?; - if let Some(package) = package { - let Some(host) = host else { - return Err(Error::new( - eyre!("Must specify host"), - ErrorKind::InvalidRequest, - )); - }; - sqlx::query!( - "DELETE FROM tor WHERE package = $1 AND interface = $2", - &package, - &host, - ) - .execute(&mut *tx) - .await?; - sqlx::query!( - "DELETE FROM network_keys WHERE package = $1 AND interface = $2", - &package, - &host, - ) - .execute(&mut *tx) - .await?; - let new_key = Key::for_host(&mut *tx, Some((package.clone(), host.clone()))).await?; - let needs_config = ctx - .db - .mutate(|v| { - let installed = v - .as_public_mut() - .as_package_data_mut() - .as_idx_mut(&package) - .or_not_found(&package)? - .as_installed_mut() - .or_not_found("installed")?; - let addrs = installed - .as_interface_addresses_mut() - .as_idx_mut(&host) - .or_not_found(&host)?; - if let Some(lan) = addrs.as_lan_address_mut().transpose_mut() { - lan.ser(&new_key.local_address())?; - } - if let Some(lan) = addrs.as_tor_address_mut().transpose_mut() { - lan.ser(&new_key.tor_address().to_string())?; - } - - // TODO - // if installed - // .as_manifest() - // .as_config() - // .transpose_ref() - // .is_some() - // { - // installed - // .as_status_mut() - // .as_configured_mut() - // .replace(&false) - // } else { - // Ok(false) - // } - Ok(false) - }) - .await?; - tx.commit().await?; - if needs_config { - ctx.services - .get(&package) - .await - .as_ref() - .ok_or_else(|| { - Error::new( - eyre!("There is no manager running for {package}"), - ErrorKind::Unknown, - ) - })? - .configure(ConfigureContext::default()) - .await?; - } else { - restart(ctx, ControlParams { id: package }).await?; - } - Ok(RequiresReboot(false)) - } else { - sqlx::query!("UPDATE account SET tor_key = NULL, network_key = gen_random_bytes(32)") - .execute(&mut *tx) - .await?; - let new_key = Key::for_host(&mut *tx, None).await?; - let url = format!("https://{}", new_key.tor_address()).parse()?; - ctx.db - .mutate(|v| { - v.as_public_mut() - .as_server_info_mut() - .as_tor_address_mut() - .ser(&url) - }) - .await?; - tx.commit().await?; - Ok(RequiresReboot(true)) - } -} diff --git a/core/startos/src/net/mdns.rs b/core/startos/src/net/mdns.rs index ee2e0fa41..af5d128a8 100644 --- a/core/startos/src/net/mdns.rs +++ b/core/startos/src/net/mdns.rs @@ -1,14 +1,10 @@ -use std::collections::BTreeMap; use std::net::Ipv4Addr; -use std::sync::{Arc, Weak}; use color_eyre::eyre::eyre; -use tokio::process::{Child, Command}; -use tokio::sync::Mutex; -use tracing::instrument; +use tokio::process::Command; +use crate::prelude::*; use crate::util::Invoke; -use crate::{Error, ResultExt}; pub async fn resolve_mdns(hostname: &str) -> Result { Ok(String::from_utf8( diff --git a/core/startos/src/net/mod.rs b/core/startos/src/net/mod.rs index a0a2ed166..f6e5ddee5 100644 --- a/core/startos/src/net/mod.rs +++ b/core/startos/src/net/mod.rs @@ -1,9 +1,8 @@ -use rpc_toolkit::{from_fn_async, AnyContext, HandlerExt, ParentHandler}; - -use crate::context::CliContext; +use rpc_toolkit::ParentHandler; pub mod dhcp; pub mod dns; +pub mod forward; pub mod host; pub mod keys; pub mod mdns; @@ -22,13 +21,4 @@ pub fn net() -> ParentHandler { ParentHandler::new() .subcommand("tor", tor::tor()) .subcommand("dhcp", dhcp::dhcp()) - .subcommand("ssl", ssl::ssl()) - .subcommand( - "rotate-key", - from_fn_async(keys::rotate_key) - .with_custom_display_fn::(|handle, result| { - Ok(keys::display_requires_reboot(handle.params, result)) - }) - .with_remote_cli::(), - ) } diff --git a/core/startos/src/net/net_controller.rs b/core/startos/src/net/net_controller.rs index 9b9145531..d9d7a5d76 100644 --- a/core/startos/src/net/net_controller.rs +++ b/core/startos/src/net/net_controller.rs @@ -1,59 +1,72 @@ -use std::collections::BTreeMap; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::collections::{BTreeMap, BTreeSet}; +use std::net::{Ipv4Addr, SocketAddr}; use std::sync::{Arc, Weak}; use color_eyre::eyre::eyre; -use models::{HostId, PackageId}; -use sqlx::PgExecutor; +use imbl::OrdMap; +use lazy_format::lazy_format; +use models::{HostId, OptionExt, PackageId}; +use patch_db::PatchDb; +use torut::onion::{OnionAddressV3, TorSecretKeyV3}; use tracing::instrument; +use crate::db::prelude::PatchDbExt; use crate::error::ErrorCollection; use crate::hostname::Hostname; use crate::net::dns::DnsController; -use crate::net::keys::Key; -use crate::net::ssl::{export_cert, export_key, SslManager}; +use crate::net::forward::LanPortForwardController; +use crate::net::host::address::HostAddress; +use crate::net::host::binding::{AddSslOptions, BindOptions}; +use crate::net::host::{Host, HostKind}; use crate::net::tor::TorController; use crate::net::vhost::{AlpnInfo, VHostController}; -use crate::volume::cert_dir; +use crate::util::serde::MaybeUtf8String; use crate::{Error, HOST_IP}; pub struct NetController { + db: PatchDb, pub(super) tor: TorController, pub(super) vhost: VHostController, pub(super) dns: DnsController, - pub(super) ssl: Arc, + pub(super) forward: LanPortForwardController, pub(super) os_bindings: Vec>, } impl NetController { #[instrument(skip_all)] pub async fn init( + db: PatchDb, tor_control: SocketAddr, tor_socks: SocketAddr, dns_bind: &[SocketAddr], - ssl: SslManager, hostname: &Hostname, - os_key: &Key, + os_tor_key: TorSecretKeyV3, ) -> Result { - let ssl = Arc::new(ssl); let mut res = Self { + db: db.clone(), tor: TorController::new(tor_control, tor_socks), - vhost: VHostController::new(ssl.clone()), + vhost: VHostController::new(db), dns: DnsController::init(dns_bind).await?, - ssl, + forward: LanPortForwardController::new(), os_bindings: Vec::new(), }; - res.add_os_bindings(hostname, os_key).await?; + res.add_os_bindings(hostname, os_tor_key).await?; Ok(res) } - async fn add_os_bindings(&mut self, hostname: &Hostname, key: &Key) -> Result<(), Error> { - let alpn = Err(AlpnInfo::Specified(vec!["http/1.1".into(), "h2".into()])); + async fn add_os_bindings( + &mut self, + hostname: &Hostname, + tor_key: TorSecretKeyV3, + ) -> Result<(), Error> { + let alpn = Err(AlpnInfo::Specified(vec![ + MaybeUtf8String("http/1.1".into()), + MaybeUtf8String("h2".into()), + ])); // Internal DNS self.vhost .add( - key.clone(), Some("embassy".into()), 443, ([127, 0, 0, 1], 80).into(), @@ -66,13 +79,7 @@ impl NetController { // LAN IP self.os_bindings.push( self.vhost - .add( - key.clone(), - None, - 443, - ([127, 0, 0, 1], 80).into(), - alpn.clone(), - ) + .add(None, 443, ([127, 0, 0, 1], 80).into(), alpn.clone()) .await?, ); @@ -80,7 +87,6 @@ impl NetController { self.os_bindings.push( self.vhost .add( - key.clone(), Some("localhost".into()), 443, ([127, 0, 0, 1], 80).into(), @@ -91,7 +97,6 @@ impl NetController { self.os_bindings.push( self.vhost .add( - key.clone(), Some(hostname.no_dot_host_name()), 443, ([127, 0, 0, 1], 80).into(), @@ -104,7 +109,6 @@ impl NetController { self.os_bindings.push( self.vhost .add( - key.clone(), Some(hostname.local_domain_name()), 443, ([127, 0, 0, 1], 80).into(), @@ -113,28 +117,26 @@ impl NetController { .await?, ); - // Tor (http) - self.os_bindings.push( - self.tor - .add(key.tor_key(), 80, ([127, 0, 0, 1], 80).into()) - .await?, - ); - - // Tor (https) + // Tor self.os_bindings.push( self.vhost .add( - key.clone(), - Some(key.tor_address().to_string()), + Some(tor_key.public().get_onion_address().to_string()), 443, ([127, 0, 0, 1], 80).into(), alpn.clone(), ) .await?, ); - self.os_bindings.push( + self.os_bindings.extend( self.tor - .add(key.tor_key(), 443, ([127, 0, 0, 1], 443).into()) + .add( + tor_key, + vec![ + (80, ([127, 0, 0, 1], 80).into()), // http + (443, ([127, 0, 0, 1], 443).into()), // https + ], + ) .await?, ); @@ -155,57 +157,15 @@ impl NetController { ip, dns, controller: Arc::downgrade(self), - tor: BTreeMap::new(), - lan: BTreeMap::new(), + binds: BTreeMap::new(), }) } +} - async fn add_tor( - &self, - key: &Key, - external: u16, - target: SocketAddr, - ) -> Result>, Error> { - let mut rcs = Vec::with_capacity(1); - rcs.push(self.tor.add(key.tor_key(), external, target).await?); - Ok(rcs) - } - - async fn remove_tor(&self, key: &Key, external: u16, rcs: Vec>) -> Result<(), Error> { - drop(rcs); - self.tor.gc(Some(key.tor_key()), Some(external)).await - } - - async fn add_lan( - &self, - key: Key, - external: u16, - target: SocketAddr, - connect_ssl: Result<(), AlpnInfo>, - ) -> Result>, Error> { - let mut rcs = Vec::with_capacity(2); - rcs.push( - self.vhost - .add( - key.clone(), - Some(key.local_address()), - external, - target.into(), - connect_ssl, - ) - .await?, - ); - // rcs.push(self.mdns.add(key.base_address()).await?); - // TODO - Ok(rcs) - } - - async fn remove_lan(&self, key: &Key, external: u16, rcs: Vec>) -> Result<(), Error> { - drop(rcs); - // self.mdns.gc(key.base_address()).await?; - // TODO - self.vhost.gc(Some(key.local_address()), external).await - } +#[derive(Default)] +struct HostBinds { + lan: BTreeMap, Arc<()>)>, + tor: BTreeMap, Vec>)>, } pub struct NetService { @@ -214,8 +174,7 @@ pub struct NetService { ip: Ipv4Addr, dns: Arc<()>, controller: Weak, - tor: BTreeMap<(HostId, u16), (Key, Vec>)>, - lan: BTreeMap<(HostId, u16), (Key, Vec>)>, + binds: BTreeMap, } impl NetService { fn net_controller(&self) -> Result, Error> { @@ -226,111 +185,196 @@ impl NetService { ) }) } - pub async fn add_tor( + + pub async fn bind( &mut self, - secrets: &mut Ex, + kind: HostKind, id: HostId, - external: u16, - internal: u16, - ) -> Result<(), Error> - where - for<'a> &'a mut Ex: PgExecutor<'a>, - { - let key = Key::for_host(secrets, Some((self.id.clone(), id.clone()))).await?; - let ctrl = self.net_controller()?; - let tor_idx = (id, external); - let mut tor = self - .tor - .remove(&tor_idx) - .unwrap_or_else(|| (key.clone(), Vec::new())); - tor.1.append( - &mut ctrl - .add_tor(&key, external, SocketAddr::new(self.ip.into(), internal)) - .await?, - ); - self.tor.insert(tor_idx, tor); - Ok(()) + internal_port: u16, + options: BindOptions, + ) -> Result<(), Error> { + let id_ref = &id; + let pkg_id = &self.id; + let host = self + .net_controller()? + .db + .mutate(|d| { + let mut ports = d.as_private().as_available_ports().de()?; + let hosts = d + .as_public_mut() + .as_package_data_mut() + .as_idx_mut(pkg_id) + .or_not_found(pkg_id)? + .as_installed_mut() + .or_not_found(pkg_id)? + .as_hosts_mut(); + hosts.add_binding(&mut ports, kind, &id, internal_port, options)?; + let host = hosts + .as_idx(&id) + .or_not_found(lazy_format!("Host {id_ref} for {pkg_id}"))? + .de()?; + d.as_private_mut().as_available_ports_mut().ser(&ports)?; + Ok(host) + }) + .await?; + self.update(id, host).await } - pub async fn remove_tor(&mut self, id: HostId, external: u16) -> Result<(), Error> { + + async fn update(&mut self, id: HostId, host: Host) -> Result<(), Error> { let ctrl = self.net_controller()?; - if let Some((key, rcs)) = self.tor.remove(&(id, external)) { - ctrl.remove_tor(&key, external, rcs).await?; + let binds = { + if !self.binds.contains_key(&id) { + self.binds.insert(id.clone(), Default::default()); + } + self.binds.get_mut(&id).unwrap() + }; + if true + // TODO: if should listen lan + { + for (port, bind) in &host.bindings { + let old_lan_bind = binds.lan.remove(port); + let old_lan_port = old_lan_bind.as_ref().map(|(external, _, _)| *external); + let lan_bind = old_lan_bind.filter(|(external, ssl, _)| { + ssl == &bind.options.add_ssl + && bind.assigned_lan_port.as_ref() == Some(external) + }); // only keep existing binding if relevant details match + if let Some(external) = bind.assigned_lan_port { + let new_lan_bind = if let Some(b) = lan_bind { + b + } else { + if let Some(ssl) = &bind.options.add_ssl { + let rc = ctrl + .vhost + .add( + None, + external, + (self.ip, *port).into(), + if bind.options.ssl { + Ok(()) + } else { + Err(ssl.alpn.clone()) + }, + ) + .await?; + (*port, Some(ssl.clone()), rc) + } else { + let rc = ctrl.forward.add(external, (self.ip, *port).into()).await?; + (*port, None, rc) + } + }; + binds.lan.insert(*port, new_lan_bind); + } + if let Some(external) = old_lan_port { + ctrl.vhost.gc(None, external).await?; + ctrl.forward.gc(external).await?; + } + } + let mut removed = BTreeSet::new(); + let mut removed_ssl = BTreeSet::new(); + binds.lan.retain(|internal, (external, ssl, _)| { + if host.bindings.contains_key(internal) { + true + } else { + if ssl.is_some() { + removed_ssl.insert(*external); + } else { + removed.insert(*external); + } + false + } + }); + for external in removed { + ctrl.forward.gc(external).await?; + } + for external in removed_ssl { + ctrl.vhost.gc(None, external).await?; + } + } + let tor_binds: OrdMap = host + .bindings + .iter() + .flat_map(|(internal, info)| { + let non_ssl = ( + info.options.preferred_external_port, + SocketAddr::from((self.ip, *internal)), + ); + if let (Some(ssl), Some(ssl_internal)) = + (&info.options.add_ssl, info.assigned_lan_port) + { + itertools::Either::Left( + [ + ( + ssl.preferred_external_port, + SocketAddr::from(([127, 0, 0, 1], ssl_internal)), + ), + non_ssl, + ] + .into_iter(), + ) + } else { + itertools::Either::Right([non_ssl].into_iter()) + } + }) + .collect(); + let mut keep_tor_addrs = BTreeSet::new(); + for addr in match host.kind { + HostKind::Multi => { + // itertools::Either::Left( + host.addresses.iter() + // ) + } // HostKind::Single | HostKind::Static => itertools::Either::Right(&host.primary), + } { + match addr { + HostAddress::Onion { address } => { + keep_tor_addrs.insert(address); + let old_tor_bind = binds.tor.remove(address); + let tor_bind = old_tor_bind.filter(|(ports, _)| ports == &tor_binds); + let new_tor_bind = if let Some(tor_bind) = tor_bind { + tor_bind + } else { + let key = ctrl + .db + .peek() + .await + .into_private() + .into_key_store() + .into_onion() + .get_key(address)?; + let rcs = ctrl + .tor + .add(key, tor_binds.clone().into_iter().collect()) + .await?; + (tor_binds.clone(), rcs) + }; + binds.tor.insert(address.clone(), new_tor_bind); + } + } + } + for addr in binds.tor.keys() { + if !keep_tor_addrs.contains(addr) { + ctrl.tor.gc(Some(addr.clone()), None).await?; + } } Ok(()) } - pub async fn add_lan( - &mut self, - secrets: &mut Ex, - id: HostId, - external: u16, - internal: u16, - connect_ssl: Result<(), AlpnInfo>, - ) -> Result<(), Error> - where - for<'a> &'a mut Ex: PgExecutor<'a>, - { - let key = Key::for_host(secrets, Some((self.id.clone(), id.clone()))).await?; - let ctrl = self.net_controller()?; - let lan_idx = (id, external); - let mut lan = self - .lan - .remove(&lan_idx) - .unwrap_or_else(|| (key.clone(), Vec::new())); - lan.1.append( - &mut ctrl - .add_lan( - key, - external, - SocketAddr::new(self.ip.into(), internal), - connect_ssl, - ) - .await?, - ); - self.lan.insert(lan_idx, lan); - Ok(()) - } - pub async fn remove_lan(&mut self, id: HostId, external: u16) -> Result<(), Error> { - let ctrl = self.net_controller()?; - if let Some((key, rcs)) = self.lan.remove(&(id, external)) { - ctrl.remove_lan(&key, external, rcs).await?; - } - Ok(()) - } - pub async fn export_cert( - &self, - secrets: &mut Ex, - id: &HostId, - ip: IpAddr, - ) -> Result<(), Error> - where - for<'a> &'a mut Ex: PgExecutor<'a>, - { - let key = Key::for_host(secrets, Some((self.id.clone(), id.clone()))).await?; - let ctrl = self.net_controller()?; - let cert = ctrl.ssl.with_certs(key, ip).await?; - let cert_dir = cert_dir(&self.id, id); - tokio::fs::create_dir_all(&cert_dir).await?; - export_key( - &cert.key().openssl_key_nistp256(), - &cert_dir.join(format!("{id}.key.pem")), - ) - .await?; - export_cert( - &cert.fullchain_nistp256(), - &cert_dir.join(format!("{id}.cert.pem")), - ) - .await?; // TODO: can upgrade to ed25519? - Ok(()) - } + pub async fn remove_all(mut self) -> Result<(), Error> { self.shutdown = true; let mut errors = ErrorCollection::new(); if let Some(ctrl) = Weak::upgrade(&self.controller) { - for ((_, external), (key, rcs)) in std::mem::take(&mut self.lan) { - errors.handle(ctrl.remove_lan(&key, external, rcs).await); - } - for ((_, external), (key, rcs)) in std::mem::take(&mut self.tor) { - errors.handle(ctrl.remove_tor(&key, external, rcs).await); + for (_, binds) in std::mem::take(&mut self.binds) { + for (_, (external, ssl, rc)) in binds.lan { + drop(rc); + if ssl.is_some() { + errors.handle(ctrl.vhost.gc(None, external).await); + } else { + errors.handle(ctrl.forward.gc(external).await); + } + } + for (addr, (_, rcs)) in binds.tor { + drop(rcs); + errors.handle(ctrl.tor.gc(Some(addr), None).await); + } } std::mem::take(&mut self.dns); errors.handle(ctrl.dns.gc(Some(self.id.clone()), self.ip).await); @@ -357,8 +401,7 @@ impl Drop for NetService { ip: Ipv4Addr::new(0, 0, 0, 0), dns: Default::default(), controller: Default::default(), - tor: Default::default(), - lan: Default::default(), + binds: BTreeMap::new(), }, ); tokio::spawn(async move { svc.remove_all().await.unwrap() }); diff --git a/core/startos/src/net/ssl.rs b/core/startos/src/net/ssl.rs index f9502c86b..44d7cf0da 100644 --- a/core/startos/src/net/ssl.rs +++ b/core/startos/src/net/ssl.rs @@ -5,6 +5,7 @@ use std::path::Path; use std::time::{SystemTime, UNIX_EPOCH}; use futures::FutureExt; +use imbl_value::InternedString; use libc::time_t; use openssl::asn1::{Asn1Integer, Asn1Time}; use openssl::bn::{BigNum, MsbOption}; @@ -14,17 +15,137 @@ use openssl::nid::Nid; use openssl::pkey::{PKey, Private}; use openssl::x509::{X509Builder, X509Extension, X509NameBuilder, X509}; use openssl::*; -use rpc_toolkit::{from_fn_async, HandlerExt, ParentHandler}; -use tokio::sync::{Mutex, RwLock}; +use patch_db::HasModel; +use serde::{Deserialize, Serialize}; +use tokio::sync::Mutex; use tracing::instrument; use crate::account::AccountInfo; -use crate::context::{CliContext, RpcContext}; use crate::hostname::Hostname; use crate::init::check_time_is_synchronized; -use crate::net::dhcp::ips; -use crate::net::keys::{Key, KeyInfo}; -use crate::{Error, ErrorKind, ResultExt, SOURCE_DATE}; +use crate::prelude::*; +use crate::util::serde::Pem; +use crate::SOURCE_DATE; + +#[derive(Debug, Deserialize, Serialize, HasModel)] +#[model = "Model"] +#[serde(rename_all = "kebab-case")] +pub struct CertStore { + pub root_key: Pem>, + pub root_cert: Pem, + pub int_key: Pem>, + pub int_cert: Pem, + pub leaves: BTreeMap>, CertData>, +} +impl CertStore { + pub fn new(account: &AccountInfo) -> Result { + let int_key = generate_key()?; + let int_cert = make_int_cert((&account.root_ca_key, &account.root_ca_cert), &int_key)?; + Ok(Self { + root_key: Pem::new(account.root_ca_key.clone()), + root_cert: Pem::new(account.root_ca_cert.clone()), + int_key: Pem::new(int_key), + int_cert: Pem::new(int_cert), + leaves: BTreeMap::new(), + }) + } +} +impl Model { + /// This function will grant any cert for any domain. It is up to the *caller* to enusure that the calling service has permission to sign a cert for the requested domain + pub fn cert_for( + &mut self, + hostnames: &BTreeSet, + ) -> Result { + let keys = if let Some(cert_data) = self + .as_leaves() + .as_idx(JsonKey::new_ref(hostnames)) + .map(|m| m.de()) + .transpose()? + { + if cert_data + .certs + .ed25519 + .not_before() + .compare(Asn1Time::days_from_now(0)?.as_ref())? + == Ordering::Less + && cert_data + .certs + .ed25519 + .not_after() + .compare(Asn1Time::days_from_now(30)?.as_ref())? + == Ordering::Greater + && cert_data + .certs + .nistp256 + .not_before() + .compare(Asn1Time::days_from_now(0)?.as_ref())? + == Ordering::Less + && cert_data + .certs + .nistp256 + .not_after() + .compare(Asn1Time::days_from_now(30)?.as_ref())? + == Ordering::Greater + { + return Ok(FullchainCertData { + root: self.as_root_cert().de()?.0, + int: self.as_int_cert().de()?.0, + leaf: cert_data, + }); + } + cert_data.keys + } else { + PKeyPair { + ed25519: PKey::generate_ed25519()?, + nistp256: PKey::from_ec_key(EcKey::generate(&*EcGroup::from_curve_name( + Nid::X9_62_PRIME256V1, + )?)?)?, + } + }; + let int_key = self.as_int_key().de()?.0; + let int_cert = self.as_int_cert().de()?.0; + let cert_data = CertData { + certs: CertPair { + ed25519: make_leaf_cert( + (&int_key, &int_cert), + (&keys.ed25519, &SANInfo::new(hostnames)), + )?, + nistp256: make_leaf_cert( + (&int_key, &int_cert), + (&keys.nistp256, &SANInfo::new(hostnames)), + )?, + }, + keys, + }; + self.as_leaves_mut() + .insert(JsonKey::new_ref(hostnames), &cert_data)?; + Ok(FullchainCertData { + root: self.as_root_cert().de()?.0, + int: self.as_int_cert().de()?.0, + leaf: cert_data, + }) + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CertData { + pub keys: PKeyPair, + pub certs: CertPair, +} + +pub struct FullchainCertData { + pub root: X509, + pub int: X509, + pub leaf: CertData, +} +impl FullchainCertData { + pub fn fullchain_ed25519(&self) -> Vec<&X509> { + vec![&self.root, &self.int, &self.leaf.certs.ed25519] + } + pub fn fullchain_nistp256(&self) -> Vec<&X509> { + vec![&self.root, &self.int, &self.leaf.certs.nistp256] + } +} static CERTIFICATE_VERSION: i32 = 2; // X509 version 3 is actually encoded as '2' in the cert because fuck you. @@ -35,62 +156,20 @@ fn unix_time(time: SystemTime) -> time_t { .unwrap_or_default() } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub struct CertPair { - pub ed25519: X509, - pub nistp256: X509, +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct PKeyPair { + #[serde(with = "crate::util::serde::pem")] + pub ed25519: PKey, + #[serde(with = "crate::util::serde::pem")] + pub nistp256: PKey, } -impl CertPair { - fn updated( - pair: Option<&Self>, - hostname: &Hostname, - signer: (&PKey, &X509), - applicant: &Key, - ip: BTreeSet, - ) -> Result<(Self, bool), Error> { - let mut updated = false; - let mut updated_cert = |cert: Option<&X509>, osk: PKey| -> Result { - let mut ips = BTreeSet::new(); - if let Some(cert) = cert { - ips.extend( - cert.subject_alt_names() - .iter() - .flatten() - .filter_map(|a| a.ipaddress()) - .filter_map(|a| match a.len() { - 4 => Some::(<[u8; 4]>::try_from(a).unwrap().into()), - 16 => Some::(<[u8; 16]>::try_from(a).unwrap().into()), - _ => None, - }), - ); - if cert - .not_before() - .compare(Asn1Time::days_from_now(0)?.as_ref())? - == Ordering::Less - && cert - .not_after() - .compare(Asn1Time::days_from_now(30)?.as_ref())? - == Ordering::Greater - && ips.is_superset(&ip) - { - return Ok(cert.clone()); - } - } - ips.extend(ip.iter().copied()); - updated = true; - make_leaf_cert(signer, (&osk, &SANInfo::new(&applicant, hostname, ips))) - }; - Ok(( - Self { - ed25519: updated_cert(pair.map(|c| &c.ed25519), applicant.openssl_key_ed25519())?, - nistp256: updated_cert( - pair.map(|c| &c.nistp256), - applicant.openssl_key_nistp256(), - )?, - }, - updated, - )) - } + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize)] +pub struct CertPair { + #[serde(with = "crate::util::serde::pem")] + pub ed25519: X509, + #[serde(with = "crate::util::serde::pem")] + pub nistp256: X509, } pub async fn root_ca_start_time() -> Result { @@ -101,51 +180,6 @@ pub async fn root_ca_start_time() -> Result { }) } -#[derive(Debug)] -pub struct SslManager { - hostname: Hostname, - root_cert: X509, - int_key: PKey, - int_cert: X509, - cert_cache: RwLock>, -} -impl SslManager { - pub fn new(account: &AccountInfo, start_time: SystemTime) -> Result { - let int_key = generate_key()?; - let int_cert = make_int_cert( - (&account.root_ca_key, &account.root_ca_cert), - &int_key, - start_time, - )?; - Ok(Self { - hostname: account.hostname.clone(), - root_cert: account.root_ca_cert.clone(), - int_key, - int_cert, - cert_cache: RwLock::new(BTreeMap::new()), - }) - } - pub async fn with_certs(&self, key: Key, ip: IpAddr) -> Result { - let mut ips = ips().await?; - ips.insert(ip); - let (pair, updated) = CertPair::updated( - self.cert_cache.read().await.get(&key), - &self.hostname, - (&self.int_key, &self.int_cert), - &key, - ips, - )?; - if updated { - self.cert_cache - .write() - .await - .insert(key.clone(), pair.clone()); - } - - Ok(key.with_certs(pair, self.int_cert.clone(), self.root_cert.clone())) - } -} - const EC_CURVE_NAME: nid::Nid = nid::Nid::X9_62_PRIME256V1; lazy_static::lazy_static! { static ref EC_GROUP: EcGroup = EcGroup::from_curve_name(EC_CURVE_NAME).unwrap(); @@ -245,18 +279,13 @@ pub fn make_root_cert( pub fn make_int_cert( signer: (&PKey, &X509), applicant: &PKey, - start_time: SystemTime, ) -> Result { let mut builder = X509Builder::new()?; builder.set_version(CERTIFICATE_VERSION)?; - let unix_start_time = unix_time(start_time); + builder.set_not_before(signer.1.not_before())?; - let embargo = Asn1Time::from_unix(unix_start_time - 86400)?; - builder.set_not_before(&embargo)?; - - let expiration = Asn1Time::from_unix(unix_start_time + (10 * 364 * 86400))?; - builder.set_not_after(&expiration)?; + builder.set_not_after(signer.1.not_after())?; builder.set_serial_number(&*rand_serial()?)?; @@ -309,13 +338,13 @@ pub fn make_int_cert( #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] pub enum MaybeWildcard { WithWildcard(String), - WithoutWildcard(String), + WithoutWildcard(InternedString), } impl MaybeWildcard { pub fn as_str(&self) -> &str { match self { MaybeWildcard::WithWildcard(s) => s.as_str(), - MaybeWildcard::WithoutWildcard(s) => s.as_str(), + MaybeWildcard::WithoutWildcard(s) => &**s, } } } @@ -334,18 +363,16 @@ pub struct SANInfo { pub ips: BTreeSet, } impl SANInfo { - pub fn new(key: &Key, hostname: &Hostname, ips: BTreeSet) -> Self { + pub fn new(hostnames: &BTreeSet) -> Self { let mut dns = BTreeSet::new(); - if let Some((id, _)) = key.host() { - dns.insert(MaybeWildcard::WithWildcard(format!("{id}.embassy"))); - dns.insert(MaybeWildcard::WithWildcard(key.local_address().to_string())); - } else { - dns.insert(MaybeWildcard::WithoutWildcard("embassy".to_owned())); - dns.insert(MaybeWildcard::WithWildcard(hostname.local_domain_name())); - dns.insert(MaybeWildcard::WithoutWildcard(hostname.no_dot_host_name())); - dns.insert(MaybeWildcard::WithoutWildcard("localhost".to_owned())); + let mut ips = BTreeSet::new(); + for hostname in hostnames { + if let Ok(ip) = hostname.parse::() { + ips.insert(ip); + } else { + dns.insert(MaybeWildcard::WithoutWildcard(hostname.clone())); // TODO: wildcards? + } } - dns.insert(MaybeWildcard::WithWildcard(key.tor_address().to_string())); Self { dns, ips } } } @@ -443,14 +470,3 @@ pub fn make_leaf_cert( let cert = builder.build(); Ok(cert) } - -pub fn ssl() -> ParentHandler { - ParentHandler::new().subcommand("size", from_fn_async(size).with_remote_cli::()) -} - -pub async fn size(ctx: RpcContext) -> Result { - Ok(format!( - "Cert Catch size: {}", - ctx.net_controller.ssl.cert_cache.read().await.len() - )) -} diff --git a/core/startos/src/net/static_server.rs b/core/startos/src/net/static_server.rs index 68f071c79..fec881795 100644 --- a/core/startos/src/net/static_server.rs +++ b/core/startos/src/net/static_server.rs @@ -11,7 +11,6 @@ use axum::routing::{any, get, post}; use axum::Router; use digest::Digest; use futures::future::ready; -use futures::{FutureExt, TryFutureExt}; use http::header::ACCEPT_ENCODING; use http::request::Parts as RequestParts; use http::{HeaderMap, Method, StatusCode}; @@ -28,7 +27,6 @@ use crate::context::{DiagnosticContext, InstallContext, RpcContext, SetupContext use crate::core::rpc_continuations::RequestGuid; use crate::db::subscribe; use crate::hostname::Hostname; -use crate::install::PKG_PUBLIC_DIR; use crate::middleware::auth::{Auth, HasValidSession}; use crate::middleware::cors::Cors; use crate::middleware::db::SyncDb; @@ -131,8 +129,7 @@ pub fn main_ui_server_router(ctx: RpcContext) -> Router { "/ws/rpc/*path", get({ let ctx = ctx.clone(); - move |headers: HeaderMap, - x::Path(path): x::Path, + move |x::Path(path): x::Path, ws: axum::extract::ws::WebSocketUpgrade| async move { match RequestGuid::from(&path) { None => { @@ -155,7 +152,6 @@ pub fn main_ui_server_router(ctx: RpcContext) -> Router { let path = request .uri() .path() - .clone() .strip_prefix("/rest/rpc/") .unwrap_or_default(); match RequestGuid::from(&path) { diff --git a/core/startos/src/net/tor.rs b/core/startos/src/net/tor.rs index 13096dab8..171404ceb 100644 --- a/core/startos/src/net/tor.rs +++ b/core/startos/src/net/tor.rs @@ -28,13 +28,44 @@ use crate::logs::{ cli_logs_generic_follow, cli_logs_generic_nofollow, fetch_logs, follow_logs, journalctl, LogFollowResponse, LogResponse, LogSource, }; +use crate::prelude::*; use crate::util::serde::{display_serializable, HandlerExtSerde, WithIoFormat}; use crate::util::Invoke; -use crate::{Error, ErrorKind, ResultExt as _}; pub const SYSTEMD_UNIT: &str = "tor@default"; const STARTING_HEALTH_TIMEOUT: u64 = 120; // 2min +#[derive(Debug, Default, Deserialize, Serialize)] +pub struct OnionStore(BTreeMap); +impl Map for OnionStore { + type Key = OnionAddressV3; + type Value = TorSecretKeyV3; + fn key_str(key: &Self::Key) -> Result, Error> { + Ok(key.get_address_without_dot_onion()) + } +} +impl OnionStore { + pub fn new() -> Self { + Self::default() + } + pub fn insert(&mut self, key: TorSecretKeyV3) { + self.0.insert(key.public().get_onion_address(), key); + } +} +impl Model { + pub fn new_key(&mut self) -> Result { + let key = TorSecretKeyV3::generate(); + self.insert(&key.public().get_onion_address(), &key)?; + Ok(key) + } + pub fn insert_key(&mut self, key: &TorSecretKeyV3) -> Result<(), Error> { + self.insert(&key.public().get_onion_address(), &key) + } + pub fn get_key(&self, address: &OnionAddressV3) -> Result { + self.as_idx(address).or_not_found(address)?.de() + } +} + enum ErrorLogSeverity { Fatal { wipe_state: bool }, Unknown { wipe_state: bool }, @@ -208,33 +239,29 @@ impl TorController { pub async fn add( &self, key: TorSecretKeyV3, - external: u16, - target: SocketAddr, - ) -> Result, Error> { + bindings: Vec<(u16, SocketAddr)>, + ) -> Result>, Error> { let (reply, res) = oneshot::channel(); self.0 .send .send(TorCommand::AddOnion { key, - external, - target, + bindings, reply, }) - .ok() - .ok_or_else(|| Error::new(eyre!("TorControl died"), ErrorKind::Tor))?; + .map_err(|_| Error::new(eyre!("TorControl died"), ErrorKind::Tor))?; res.await - .ok() - .ok_or_else(|| Error::new(eyre!("TorControl died"), ErrorKind::Tor)) + .map_err(|_| Error::new(eyre!("TorControl died"), ErrorKind::Tor)) } pub async fn gc( &self, - key: Option, + addr: Option, external: Option, ) -> Result<(), Error> { self.0 .send - .send(TorCommand::GC { key, external }) + .send(TorCommand::GC { addr, external }) .ok() .ok_or_else(|| Error::new(eyre!("TorControl died"), ErrorKind::Tor)) } @@ -279,12 +306,11 @@ type AuthenticatedConnection = AuthenticatedConn< enum TorCommand { AddOnion { key: TorSecretKeyV3, - external: u16, - target: SocketAddr, - reply: oneshot::Sender>, + bindings: Vec<(u16, SocketAddr)>, + reply: oneshot::Sender>>, }, GC { - key: Option, + addr: Option, external: Option, }, GetInfo { @@ -302,7 +328,13 @@ async fn torctl( tor_control: SocketAddr, tor_socks: SocketAddr, recv: &mut mpsc::UnboundedReceiver, - services: &mut BTreeMap<[u8; 64], BTreeMap>>>, + services: &mut BTreeMap< + OnionAddressV3, + ( + TorSecretKeyV3, + BTreeMap>>, + ), + >, wipe_state: &AtomicBool, health_timeout: &mut Duration, ) -> Result<(), Error> { @@ -420,27 +452,32 @@ async fn torctl( match command { TorCommand::AddOnion { key, - external, - target, + bindings, reply, } => { - let mut service = if let Some(service) = services.remove(&key.as_bytes()) { + let addr = key.public().get_onion_address(); + let mut service = if let Some((_key, service)) = services.remove(&addr) { + debug_assert_eq!(key, _key); service } else { BTreeMap::new() }; - let mut binding = service.remove(&external).unwrap_or_default(); - let rc = if let Some(rc) = - Weak::upgrade(&binding.remove(&target).unwrap_or_default()) - { - rc - } else { - Arc::new(()) - }; - binding.insert(target, Arc::downgrade(&rc)); - service.insert(external, binding); - services.insert(key.as_bytes(), service); - reply.send(rc).unwrap_or_default(); + let mut rcs = Vec::with_capacity(bindings.len()); + for (external, target) in bindings { + let mut binding = service.remove(&external).unwrap_or_default(); + let rc = if let Some(rc) = + Weak::upgrade(&binding.remove(&target).unwrap_or_default()) + { + rc + } else { + Arc::new(()) + }; + binding.insert(target, Arc::downgrade(&rc)); + service.insert(external, binding); + rcs.push(rc); + } + services.insert(addr, (key, service)); + reply.send(rcs).unwrap_or_default(); } TorCommand::GetInfo { reply, .. } => { reply @@ -480,8 +517,7 @@ async fn torctl( ) .await?; - for (key, service) in std::mem::take(services) { - let key = TorSecretKeyV3::from(key); + for (addr, (key, service)) in std::mem::take(services) { let bindings = service .iter() .flat_map(|(ext, int)| { @@ -491,7 +527,7 @@ async fn torctl( }) .collect::>(); if !bindings.is_empty() { - services.insert(key.as_bytes(), service); + services.insert(addr, (key.clone(), service)); connection .add_onion_v3(&key, false, false, false, None, &mut bindings.iter()) .await?; @@ -503,31 +539,33 @@ async fn torctl( match command { TorCommand::AddOnion { key, - external, - target, + bindings, reply, } => { let mut rm_res = Ok(()); - let onion_base = key - .public() - .get_onion_address() - .get_address_without_dot_onion(); - let mut service = if let Some(service) = services.remove(&key.as_bytes()) { + let addr = key.public().get_onion_address(); + let onion_base = addr.get_address_without_dot_onion(); + let mut service = if let Some((_key, service)) = services.remove(&addr) { + debug_assert_eq!(_key, key); rm_res = connection.del_onion(&onion_base).await; service } else { BTreeMap::new() }; - let mut binding = service.remove(&external).unwrap_or_default(); - let rc = if let Some(rc) = - Weak::upgrade(&binding.remove(&target).unwrap_or_default()) - { - rc - } else { - Arc::new(()) - }; - binding.insert(target, Arc::downgrade(&rc)); - service.insert(external, binding); + let mut rcs = Vec::with_capacity(bindings.len()); + for (external, target) in bindings { + let mut binding = service.remove(&external).unwrap_or_default(); + let rc = if let Some(rc) = + Weak::upgrade(&binding.remove(&target).unwrap_or_default()) + { + rc + } else { + Arc::new(()) + }; + binding.insert(target, Arc::downgrade(&rc)); + service.insert(external, binding); + rcs.push(rc); + } let bindings = service .iter() .flat_map(|(ext, int)| { @@ -536,25 +574,21 @@ async fn torctl( .map(|(addr, _)| (*ext, SocketAddr::from(*addr))) }) .collect::>(); - services.insert(key.as_bytes(), service); - reply.send(rc).unwrap_or_default(); + services.insert(addr, (key.clone(), service)); + reply.send(rcs).unwrap_or_default(); rm_res?; connection .add_onion_v3(&key, false, false, false, None, &mut bindings.iter()) .await?; } - TorCommand::GC { key, external } => { - for key in if key.is_some() { - itertools::Either::Left(key.into_iter().map(|k| k.as_bytes())) + TorCommand::GC { addr, external } => { + for addr in if addr.is_some() { + itertools::Either::Left(addr.into_iter()) } else { itertools::Either::Right(services.keys().cloned().collect_vec().into_iter()) } { - let key = TorSecretKeyV3::from(key); - let onion_base = key - .public() - .get_onion_address() - .get_address_without_dot_onion(); - if let Some(mut service) = services.remove(&key.as_bytes()) { + if let Some((key, mut service)) = services.remove(&addr) { + let onion_base: String = addr.get_address_without_dot_onion(); for external in if external.is_some() { itertools::Either::Left(external.into_iter()) } else { @@ -583,7 +617,7 @@ async fn torctl( }) .collect::>(); if !bindings.is_empty() { - services.insert(key.as_bytes(), service); + services.insert(addr, (key.clone(), service)); } rm_res?; if !bindings.is_empty() { diff --git a/core/startos/src/net/vhost.rs b/core/startos/src/net/vhost.rs index 3d60544db..88cb759a0 100644 --- a/core/startos/src/net/vhost.rs +++ b/core/startos/src/net/vhost.rs @@ -5,7 +5,9 @@ use std::time::Duration; use color_eyre::eyre::eyre; use helpers::NonDetachingJoinHandle; +use imbl_value::InternedString; use models::ResultExt; +use serde::{Deserialize, Serialize}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{Mutex, RwLock}; use tokio_rustls::rustls::pki_types::{ @@ -16,38 +18,36 @@ use tokio_rustls::rustls::{RootCertStore, ServerConfig}; use tokio_rustls::{LazyConfigAcceptor, TlsConnector}; use tracing::instrument; -use crate::net::keys::Key; -use crate::net::ssl::SslManager; use crate::prelude::*; use crate::util::io::{BackTrackingReader, TimeoutStream}; +use crate::util::serde::MaybeUtf8String; // not allowed: <=1024, >=32768, 5355, 5432, 9050, 6010, 9051, 5353 pub struct VHostController { - ssl: Arc, + db: PatchDb, servers: Mutex>, } impl VHostController { - pub fn new(ssl: Arc) -> Self { + pub fn new(db: PatchDb) -> Self { Self { - ssl, + db, servers: Mutex::new(BTreeMap::new()), } } #[instrument(skip_all)] pub async fn add( &self, - key: Key, hostname: Option, external: u16, target: SocketAddr, - connect_ssl: Result<(), AlpnInfo>, + connect_ssl: Result<(), AlpnInfo>, // Ok: yes, connect using ssl, pass through alpn; Err: connect tcp, use provided strategy for alpn ) -> Result, Error> { let mut writable = self.servers.lock().await; let server = if let Some(server) = writable.remove(&external) { server } else { - VHostServer::new(external, self.ssl.clone()).await? + VHostServer::new(external, self.db.clone()).await? }; let rc = server .add( @@ -55,7 +55,6 @@ impl VHostController { TargetInfo { addr: target, connect_ssl, - key, }, ) .await; @@ -79,13 +78,18 @@ impl VHostController { struct TargetInfo { addr: SocketAddr, connect_ssl: Result<(), AlpnInfo>, - key: Key, } -#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] pub enum AlpnInfo { Reflect, - Specified(Vec>), + Specified(Vec), +} +impl Default for AlpnInfo { + fn default() -> Self { + Self::Reflect + } } struct VHostServer { @@ -94,7 +98,7 @@ struct VHostServer { } impl VHostServer { #[instrument(skip_all)] - async fn new(port: u16, ssl: Arc) -> Result { + async fn new(port: u16, db: PatchDb) -> Result { // check if port allowed let listener = TcpListener::bind(SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), port)) .await @@ -105,13 +109,13 @@ impl VHostServer { _thread: tokio::spawn(async move { loop { match listener.accept().await { - Ok((stream, _)) => { + Ok((stream, sock_addr)) => { let stream = Box::pin(TimeoutStream::new(stream, Duration::from_secs(300))); let mut stream = BackTrackingReader::new(stream); stream.start_buffering(); let mapping = mapping.clone(); - let ssl = ssl.clone(); + let db = db.clone(); tokio::spawn(async move { if let Err(e) = async { let mid = match LazyConfigAcceptor::new( @@ -167,6 +171,7 @@ impl VHostServer { .find(|(_, rc)| rc.strong_count() > 0) .or_else(|| { if target_name + .as_ref() .map(|s| s.parse::().is_ok()) .unwrap_or(true) { @@ -184,8 +189,22 @@ impl VHostServer { if let Some(target) = target { let mut tcp_stream = TcpStream::connect(target.addr).await?; - let key = - ssl.with_certs(target.key, target.addr.ip()).await?; + let hostnames = target_name + .as_ref() + .into_iter() + .map(InternedString::intern) + .chain(std::iter::once(InternedString::from_display( + &sock_addr.ip(), + ))) + .collect(); + let key = db + .mutate(|v| { + v.as_private_mut() + .as_key_store_mut() + .as_local_certs_mut() + .cert_for(&hostnames) + }) + .await?; let cfg = ServerConfig::builder() .with_no_client_auth(); let mut cfg = @@ -202,8 +221,9 @@ impl VHostServer { }) .collect::>()?, PrivateKeyDer::from(PrivatePkcs8KeyDer::from( - key.key() - .openssl_key_ed25519() + key.leaf + .keys + .ed25519 .private_key_to_pkcs8()?, )), ) @@ -218,8 +238,9 @@ impl VHostServer { }) .collect::>()?, PrivateKeyDer::from(PrivatePkcs8KeyDer::from( - key.key() - .openssl_key_nistp256() + key.leaf + .keys + .nistp256 .private_key_to_pkcs8()?, )), ) @@ -233,7 +254,7 @@ impl VHostServer { let mut store = RootCertStore::empty(); store.add( CertificateDer::from( - key.root_ca().to_der()?, + key.root.to_der()?, ), ).with_kind(crate::ErrorKind::OpenSsl)?; store @@ -249,9 +270,9 @@ impl VHostServer { let mut target_stream = TlsConnector::from(Arc::new(client_cfg)) .connect_with( - ServerName::try_from( - key.key().internal_address(), - ).with_kind(crate::ErrorKind::OpenSsl)?, + ServerName::IpAddress( + target.addr.ip().into(), + ), tcp_stream, |conn| { cfg.alpn_protocols.extend( @@ -302,7 +323,7 @@ impl VHostServer { .await } Err(AlpnInfo::Specified(alpn)) => { - cfg.alpn_protocols = alpn; + cfg.alpn_protocols = alpn.into_iter().map(|a| a.0).collect(); let mut tls_stream = match mid.into_stream(Arc::new(cfg)).await { Ok(a) => a, diff --git a/core/startos/src/notifications.rs b/core/startos/src/notifications.rs index f16eab176..f696b27b7 100644 --- a/core/startos/src/notifications.rs +++ b/core/startos/src/notifications.rs @@ -1,24 +1,23 @@ -use std::collections::HashMap; +use std::collections::BTreeMap; use std::fmt; use std::str::FromStr; -use chrono::{DateTime, TimeZone, Utc}; +use chrono::{DateTime, Utc}; use clap::builder::ValueParserFactory; use clap::Parser; use color_eyre::eyre::eyre; +use imbl_value::InternedString; use models::PackageId; use rpc_toolkit::{command, from_fn_async, HandlerExt, ParentHandler}; use serde::{Deserialize, Serialize}; -use sqlx::PgPool; -use tokio::sync::Mutex; use tracing::instrument; use crate::backup::BackupReport; use crate::context::{CliContext, RpcContext}; +use crate::db::model::DatabaseModel; use crate::prelude::*; use crate::util::clap::FromStrParser; use crate::util::serde::HandlerExtSerde; -use crate::{Error, ErrorKind, ResultExt}; // #[command(subcommands(list, delete, delete_before, create))] pub fn notification() -> ParentHandler { @@ -53,132 +52,102 @@ pub fn notification() -> ParentHandler { #[serde(rename_all = "kebab-case")] #[command(rename_all = "kebab-case")] pub struct ListParams { - before: Option, - - limit: Option, + before: Option, + limit: Option, } // #[command(display(display_serializable))] #[instrument(skip_all)] pub async fn list( ctx: RpcContext, ListParams { before, limit }: ListParams, -) -> Result, Error> { - let limit = limit.unwrap_or(40); - match before { - None => { - let records = sqlx::query!( - "SELECT id, package_id, created_at, code, level, title, message, data FROM notifications ORDER BY id DESC LIMIT $1", - limit as i64 - ).fetch_all(&ctx.secret_store).await?; - let notifs = records - .into_iter() - .map(|r| { - Ok(Notification { - id: r.id as u32, - package_id: r.package_id.and_then(|p| p.parse().ok()), - created_at: Utc.from_utc_datetime(&r.created_at), - code: r.code as u32, - level: match r.level.parse::() { - Ok(a) => a, - Err(e) => return Err(e.into()), - }, - title: r.title, - message: r.message, - data: match r.data { - None => serde_json::Value::Null, - Some(v) => match v.parse::() { - Ok(a) => a, - Err(e) => { - return Err(Error::new( - eyre!("Invalid Notification Data: {}", e), - ErrorKind::ParseDbField, - )) - } - }, - }, - }) - }) - .collect::, Error>>()?; - - ctx.db - .mutate(|d| { - d.as_public_mut() +) -> Result, Error> { + ctx.db + .mutate(|db| { + let limit = limit.unwrap_or(40); + match before { + None => { + let records = db + .as_private() + .as_notifications() + .as_entries()? + .into_iter() + .take(limit); + let notifs = records + .into_iter() + .map(|(id, notification)| { + Ok(NotificationWithId { + id, + notification: notification.de()?, + }) + }) + .collect::, Error>>()?; + db.as_public_mut() .as_server_info_mut() .as_unread_notification_count_mut() - .ser(&0) - }) - .await?; - Ok(notifs) - } - Some(before) => { - let records = sqlx::query!( - "SELECT id, package_id, created_at, code, level, title, message, data FROM notifications WHERE id < $1 ORDER BY id DESC LIMIT $2", - before, - limit as i64 - ).fetch_all(&ctx.secret_store).await?; - let res = records - .into_iter() - .map(|r| { - Ok(Notification { - id: r.id as u32, - package_id: r.package_id.and_then(|p| p.parse().ok()), - created_at: Utc.from_utc_datetime(&r.created_at), - code: r.code as u32, - level: match r.level.parse::() { - Ok(a) => a, - Err(e) => return Err(e.into()), - }, - title: r.title, - message: r.message, - data: match r.data { - None => serde_json::Value::Null, - Some(v) => match v.parse::() { - Ok(a) => a, - Err(e) => { - return Err(Error::new( - eyre!("Invalid Notification Data: {}", e), - ErrorKind::ParseDbField, - )) - } - }, - }, - }) - }) - .collect::, Error>>()?; - Ok(res) - } - } + .ser(&0)?; + Ok(notifs) + } + Some(before) => { + let records = db + .as_private() + .as_notifications() + .as_entries()? + .into_iter() + .filter(|(id, _)| *id < before) + .take(limit); + records + .into_iter() + .map(|(id, notification)| { + Ok(NotificationWithId { + id, + notification: notification.de()?, + }) + }) + .collect() + } + } + }) + .await } #[derive(Deserialize, Serialize, Parser)] #[serde(rename_all = "kebab-case")] #[command(rename_all = "kebab-case")] pub struct DeleteParams { - id: i32, + id: u32, } pub async fn delete(ctx: RpcContext, DeleteParams { id }: DeleteParams) -> Result<(), Error> { - sqlx::query!("DELETE FROM notifications WHERE id = $1", id) - .execute(&ctx.secret_store) - .await?; - Ok(()) + ctx.db + .mutate(|db| { + db.as_private_mut().as_notifications_mut().remove(&id)?; + Ok(()) + }) + .await } #[derive(Deserialize, Serialize, Parser)] #[serde(rename_all = "kebab-case")] #[command(rename_all = "kebab-case")] pub struct DeleteBeforeParams { - before: i32, + before: u32, } pub async fn delete_before( ctx: RpcContext, DeleteBeforeParams { before }: DeleteBeforeParams, ) -> Result<(), Error> { - sqlx::query!("DELETE FROM notifications WHERE id < $1", before) - .execute(&ctx.secret_store) - .await?; - Ok(()) + ctx.db + .mutate(|db| { + for id in db.as_private().as_notifications().keys()? { + if id < before { + db.as_private_mut().as_notifications_mut().remove(&id)?; + } + } + Ok(()) + }) + .await } + #[derive(Deserialize, Serialize, Parser)] #[serde(rename_all = "kebab-case")] #[command(rename_all = "kebab-case")] @@ -198,8 +167,8 @@ pub async fn create( message, }: CreateParams, ) -> Result<(), Error> { - ctx.notification_manager - .notify(ctx.db.clone(), package, level, title, message, (), None) + ctx.db + .mutate(|db| notify(db, package, level, title, message, ())) .await } @@ -254,120 +223,95 @@ impl fmt::Display for InvalidNotificationLevel { write!(f, "Invalid Notification Level: {}", self.0) } } -#[derive(Debug, serde::Serialize, serde::Deserialize)] + +#[derive(Debug, Default, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub struct Notifications(pub BTreeMap); +impl Notifications { + pub fn new() -> Self { + Self(BTreeMap::new()) + } +} +impl Map for Notifications { + type Key = u32; + type Value = Notification; + fn key_str(key: &Self::Key) -> Result, Error> { + Self::key_string(key) + } + fn key_string(key: &Self::Key) -> Result { + Ok(InternedString::from_display(key)) + } +} + +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "kebab-case")] pub struct Notification { - id: u32, package_id: Option, created_at: DateTime, code: u32, level: NotificationLevel, title: String, message: String, - data: serde_json::Value, + data: Value, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub struct NotificationWithId { + id: u32, + #[serde(flatten)] + notification: Notification, } pub trait NotificationType: serde::Serialize + for<'de> serde::Deserialize<'de> + std::fmt::Debug { - const CODE: i32; + const CODE: u32; } impl NotificationType for () { - const CODE: i32 = 0; + const CODE: u32 = 0; } impl NotificationType for BackupReport { - const CODE: i32 = 1; + const CODE: u32 = 1; } -pub struct NotificationManager { - sqlite: PgPool, - cache: Mutex, NotificationLevel, String), i64>>, -} -impl NotificationManager { - pub fn new(sqlite: PgPool) -> Self { - NotificationManager { - sqlite, - cache: Mutex::new(HashMap::new()), - } - } - #[instrument(skip(db, subtype, self))] - pub async fn notify( - &self, - db: PatchDb, - package_id: Option, - level: NotificationLevel, - title: String, - message: String, - subtype: T, - debounce_interval: Option, - ) -> Result<(), Error> { - let peek = db.peek().await; - if !self - .should_notify(&package_id, &level, &title, debounce_interval) - .await - { - return Ok(()); - } - let mut count = peek - .as_public() - .as_server_info() - .as_unread_notification_count() - .de()?; - let sql_package_id = package_id.as_ref().map(|p| &**p); - let sql_code = T::CODE; - let sql_level = format!("{}", level); - let sql_data = - serde_json::to_string(&subtype).with_kind(crate::ErrorKind::Serialization)?; - sqlx::query!( - "INSERT INTO notifications (package_id, code, level, title, message, data) VALUES ($1, $2, $3, $4, $5, $6)", - sql_package_id, - sql_code as i32, - sql_level, - title, - message, - sql_data - ).execute(&self.sqlite).await?; - count += 1; - db.mutate(|db| { - db.as_public_mut() - .as_server_info_mut() - .as_unread_notification_count_mut() - .ser(&count) - }) - .await - } - async fn should_notify( - &self, - package_id: &Option, - level: &NotificationLevel, - title: &String, - debounce_interval: Option, - ) -> bool { - let mut guard = self.cache.lock().await; - let k = (package_id.clone(), level.clone(), title.clone()); - let v = (*guard).get(&k); - match v { - None => { - (*guard).insert(k, Utc::now().timestamp()); - true - } - Some(last_issued) => match debounce_interval { - None => { - (*guard).insert(k, Utc::now().timestamp()); - true - } - Some(interval) => { - if last_issued + interval as i64 > Utc::now().timestamp() { - false - } else { - (*guard).insert(k, Utc::now().timestamp()); - true - } - } - }, - } - } +#[instrument(skip(subtype, db))] +pub fn notify( + db: &mut DatabaseModel, + package_id: Option, + level: NotificationLevel, + title: String, + message: String, + subtype: T, +) -> Result<(), Error> { + let data = to_value(&subtype)?; + db.as_public_mut() + .as_server_info_mut() + .as_unread_notification_count_mut() + .mutate(|c| { + *c += 1; + Ok(()) + })?; + let id = db + .as_private() + .as_notifications() + .keys()? + .into_iter() + .max() + .map_or(0, |id| id + 1); + db.as_private_mut().as_notifications_mut().insert( + &id, + &Notification { + package_id, + created_at: Utc::now(), + code: T::CODE, + level, + title, + message, + data, + }, + ) } #[test] diff --git a/core/startos/src/registry/admin.rs b/core/startos/src/registry/admin.rs index 9f0033e96..95cfcec8f 100644 --- a/core/startos/src/registry/admin.rs +++ b/core/startos/src/registry/admin.rs @@ -133,7 +133,7 @@ pub async fn publish( .with_prefix("[1/3]") .with_message("Querying s9pk"); pb.enable_steady_tick(Duration::from_millis(200)); - let mut s9pk = S9pk::open(&path, None).await?; + let s9pk = S9pk::open(&path, None).await?; let m = s9pk.as_manifest().clone(); pb.set_style(plain_line_style.clone()); pb.abandon(); @@ -144,7 +144,7 @@ pub async fn publish( .with_prefix("[1/3]") .with_message("Verifying s9pk"); pb.enable_steady_tick(Duration::from_millis(200)); - let mut s9pk = S9pk::open(&path, None).await?; + let s9pk = S9pk::open(&path, None).await?; // s9pk.validate().await?; todo!(); let m = s9pk.as_manifest().clone(); diff --git a/core/startos/src/s9pk/merkle_archive/mod.rs b/core/startos/src/s9pk/merkle_archive/mod.rs index abddb3c1e..afd00032a 100644 --- a/core/startos/src/s9pk/merkle_archive/mod.rs +++ b/core/startos/src/s9pk/merkle_archive/mod.rs @@ -1,7 +1,5 @@ use std::path::Path; -use std::sync::Arc; -use ed25519::signature::Keypair; use ed25519_dalek::{Signature, SigningKey, VerifyingKey}; use tokio::io::AsyncRead; diff --git a/core/startos/src/s9pk/merkle_archive/source/multi_cursor_file.rs b/core/startos/src/s9pk/merkle_archive/source/multi_cursor_file.rs index afb808471..7add68e6f 100644 --- a/core/startos/src/s9pk/merkle_archive/source/multi_cursor_file.rs +++ b/core/startos/src/s9pk/merkle_archive/source/multi_cursor_file.rs @@ -1,7 +1,7 @@ -use std::os::fd::{AsRawFd, FromRawFd, RawFd}; +use std::io::SeekFrom; +use std::os::fd::{AsRawFd, RawFd}; use std::path::{Path, PathBuf}; use std::sync::Arc; -use std::{borrow::Borrow, io::SeekFrom}; use tokio::fs::File; use tokio::io::{AsyncRead, AsyncReadExt}; diff --git a/core/startos/src/service/config.rs b/core/startos/src/service/config.rs index 06dc8bd55..e1294a465 100644 --- a/core/startos/src/service/config.rs +++ b/core/startos/src/service/config.rs @@ -1,6 +1,4 @@ -use std::collections::BTreeMap; - -use models::{ActionId, PackageId, ProcedureName}; +use models::ProcedureName; use crate::config::ConfigureContext; use crate::prelude::*; diff --git a/core/startos/src/service/mod.rs b/core/startos/src/service/mod.rs index 3335be1ca..1efb116d5 100644 --- a/core/startos/src/service/mod.rs +++ b/core/startos/src/service/mod.rs @@ -1,5 +1,5 @@ +use std::sync::Arc; use std::time::Duration; -use std::{ops::Deref, sync::Arc}; use chrono::{DateTime, Utc}; use clap::Parser; @@ -10,25 +10,25 @@ use persistent_container::PersistentContainer; use rpc_toolkit::{from_fn_async, CallRemoteHandler, Empty, Handler, HandlerArgs}; use serde::{Deserialize, Serialize}; use start_stop::StartStop; -use tokio::sync::{watch, Notify}; +use tokio::sync::Notify; use crate::action::ActionResult; use crate::config::action::ConfigRes; use crate::context::{CliContext, RpcContext}; use crate::core::rpc_continuations::RequestGuid; use crate::db::model::{ - CurrentDependencies, CurrentDependents, InstalledPackageInfo, PackageDataEntry, - PackageDataEntryInstalled, PackageDataEntryMatchModel, StaticFiles, + InstalledPackageInfo, PackageDataEntry, PackageDataEntryInstalled, PackageDataEntryMatchModel, + StaticFiles, }; use crate::disk::mount::guard::GenericMountGuard; use crate::install::PKG_ARCHIVE_DIR; use crate::prelude::*; -use crate::progress::{self, NamedProgress, Progress}; +use crate::progress::{NamedProgress, Progress}; use crate::s9pk::S9pk; use crate::service::service_map::InstallProgressHandles; -use crate::service::transition::{TempDesiredState, TransitionKind, TransitionState}; +use crate::service::transition::TransitionKind; use crate::status::health_check::HealthCheckResult; -use crate::status::{DependencyConfigErrors, MainStatus, Status}; +use crate::status::{MainStatus, Status}; use crate::util::actor::{Actor, BackgroundJobs, SimpleActor}; use crate::volume::data_dir; @@ -289,6 +289,7 @@ impl Service { marketplace_url: None, // TODO manifest: manifest.clone(), last_backup: None, // TODO + hosts: Default::default(), // TODO store: Value::Null, // TODO store_exposed_dependents: Default::default(), // TODO store_exposed_ui: Default::default(), // TODO diff --git a/core/startos/src/service/persistent_container.rs b/core/startos/src/service/persistent_container.rs index 28067cdd8..523b5f19a 100644 --- a/core/startos/src/service/persistent_container.rs +++ b/core/startos/src/service/persistent_container.rs @@ -15,19 +15,18 @@ use tokio::process::Command; use tokio::sync::{oneshot, watch, Mutex, OnceCell}; use tracing::instrument; -use super::{ - service_effect_handler::{service_effect_handler, EffectContext}, - transition::{TempDesiredState, TransitionKind}, -}; -use super::{transition::TransitionState, ServiceActorSeed}; +use super::service_effect_handler::{service_effect_handler, EffectContext}; +use super::transition::{TransitionKind, TransitionState}; +use super::ServiceActorSeed; use crate::context::RpcContext; use crate::disk::mount::filesystem::bind::Bind; use crate::disk::mount::filesystem::idmapped::IdMapped; use crate::disk::mount::filesystem::loop_dev::LoopDev; use crate::disk::mount::filesystem::overlayfs::OverlayGuard; use crate::disk::mount::filesystem::{MountType, ReadOnly}; -use crate::disk::mount::guard::{GenericMountGuard, MountGuard}; +use crate::disk::mount::guard::MountGuard; use crate::lxc::{LxcConfig, LxcContainer, HOST_RPC_SERVER_SOCKET}; +use crate::net::net_controller::NetService; use crate::prelude::*; use crate::s9pk::merkle_archive::source::FileSource; use crate::s9pk::S9pk; @@ -94,6 +93,7 @@ pub struct PersistentContainer { assets: BTreeMap, pub(super) overlays: Arc>>, pub(super) state: Arc>, + pub(super) net_service: Mutex, } impl PersistentContainer { @@ -178,6 +178,10 @@ impl PersistentContainer { .await?; } } + let net_service = ctx + .net_controller + .create_service(s9pk.as_manifest().id.clone(), lxc_container.ip()) + .await?; Ok(Self { s9pk, lxc_container: OnceCell::new_with(Some(lxc_container)), @@ -189,6 +193,7 @@ impl PersistentContainer { assets, overlays: Arc::new(Mutex::new(BTreeMap::new())), state: Arc::new(watch::channel(ServiceState::new(start)).0), + net_service: Mutex::new(net_service), }) } diff --git a/core/startos/src/service/rpc.rs b/core/startos/src/service/rpc.rs index 05e6dcfab..6823a7189 100644 --- a/core/startos/src/service/rpc.rs +++ b/core/startos/src/service/rpc.rs @@ -2,7 +2,7 @@ use std::time::Duration; use imbl_value::Value; use models::ProcedureName; -use rpc_toolkit::yajrc::{RpcError, RpcMethod}; +use rpc_toolkit::yajrc::RpcMethod; use rpc_toolkit::Empty; use crate::prelude::*; diff --git a/core/startos/src/service/service_effect_handler.rs b/core/startos/src/service/service_effect_handler.rs index 3978b64c7..ea2228e81 100644 --- a/core/startos/src/service/service_effect_handler.rs +++ b/core/startos/src/service/service_effect_handler.rs @@ -1,11 +1,10 @@ +use std::ffi::OsString; use std::os::unix::process::CommandExt; use std::path::{Path, PathBuf}; use std::str::FromStr; use std::sync::{Arc, Weak}; -use std::{ffi::OsString, time::Instant}; -use chrono::Utc; -use clap::builder::{TypedValueParser, ValueParserFactory}; +use clap::builder::ValueParserFactory; use clap::Parser; use imbl_value::{json, InternedString}; use models::{ActionId, HealthCheckId, ImageId, PackageId}; @@ -13,19 +12,18 @@ use patch_db::json_ptr::JsonPointer; use rpc_toolkit::{from_fn, from_fn_async, AnyContext, Context, Empty, HandlerExt, ParentHandler}; use tokio::process::Command; +use crate::db::model::ExposedUI; +use crate::disk::mount::filesystem::idmapped::IdMapped; use crate::disk::mount::filesystem::loop_dev::LoopDev; use crate::disk::mount::filesystem::overlayfs::OverlayGuard; use crate::prelude::*; use crate::s9pk::rpc::SKIP_ENV; use crate::service::cli::ContainerCliContext; -use crate::service::start_stop::StartStop; use crate::service::ServiceActorSeed; -use crate::status::health_check::HealthCheckResult; +use crate::status::health_check::{HealthCheckResult, HealthCheckString}; use crate::status::MainStatus; use crate::util::clap::FromStrParser; use crate::util::{new_guid, Invoke}; -use crate::{db::model::ExposedUI, service::RunningStatus}; -use crate::{disk::mount::filesystem::idmapped::IdMapped, status::health_check::HealthCheckString}; use crate::{echo, ARCH}; #[derive(Clone)] diff --git a/core/startos/src/service/service_map.rs b/core/startos/src/service/service_map.rs index 7ac555aff..f555be531 100644 --- a/core/startos/src/service/service_map.rs +++ b/core/startos/src/service/service_map.rs @@ -12,12 +12,12 @@ use tracing::instrument; use crate::context::RpcContext; use crate::db::model::{ - InstalledPackageInfo, PackageDataEntry, PackageDataEntryInstalled, PackageDataEntryInstalling, + PackageDataEntry, PackageDataEntryInstalled, PackageDataEntryInstalling, PackageDataEntryRestoring, PackageDataEntryUpdating, StaticFiles, }; use crate::disk::mount::guard::GenericMountGuard; use crate::install::PKG_ARCHIVE_DIR; -use crate::notifications::NotificationLevel; +use crate::notifications::{notify, NotificationLevel}; use crate::prelude::*; use crate::progress::{ FullProgressTracker, FullProgressTrackerHandle, PhaseProgressTrackerHandle, @@ -370,17 +370,19 @@ impl ServiceReloadInfo { .load(&self.ctx, &self.id, LoadDisposition::Undo) .await?; if let Some(error) = error { + let error_string = error.to_string(); self.ctx - .notification_manager - .notify( - self.ctx.db.clone(), - Some(self.id.clone()), - NotificationLevel::Error, - format!("{} Failed", self.operation), - error.to_string(), - (), - None, - ) + .db + .mutate(|db| { + notify( + db, + Some(self.id.clone()), + NotificationLevel::Error, + format!("{} Failed", self.operation), + error_string, + (), + ) + }) .await?; } Ok(()) diff --git a/core/startos/src/service/transition/mod.rs b/core/startos/src/service/transition/mod.rs index cd7979cae..af62ccc1c 100644 --- a/core/startos/src/service/transition/mod.rs +++ b/core/startos/src/service/transition/mod.rs @@ -1,15 +1,13 @@ use std::sync::Arc; -use std::{fmt::Display, ops::Deref}; use futures::{Future, FutureExt}; use tokio::sync::watch; +use super::persistent_container::ServiceState; use crate::service::start_stop::StartStop; use crate::util::actor::BackgroundJobs; use crate::util::future::{CancellationHandle, RemoteCancellable}; -use super::persistent_container::ServiceState; - pub mod backup; pub mod restart; diff --git a/core/startos/src/setup.rs b/core/startos/src/setup.rs index 9c8d54db3..8275f2d61 100644 --- a/core/startos/src/setup.rs +++ b/core/startos/src/setup.rs @@ -5,10 +5,10 @@ use std::time::Duration; use color_eyre::eyre::eyre; use josekit::jwk::Jwk; use openssl::x509::X509; +use patch_db::json_ptr::ROOT; use rpc_toolkit::yajrc::RpcError; use rpc_toolkit::{from_fn_async, HandlerExt, ParentHandler}; use serde::{Deserialize, Serialize}; -use sqlx::Connection; use tokio::fs::File; use tokio::io::AsyncWriteExt; use tokio::try_join; @@ -20,6 +20,7 @@ use crate::backup::restore::recover_full_embassy; use crate::backup::target::BackupTargetFS; use crate::context::setup::SetupResult; use crate::context::SetupContext; +use crate::db::model::Database; use crate::disk::fsck::RepairStrategy; use crate::disk::main::DEFAULT_PASSWORD; use crate::disk::mount::filesystem::cifs::Cifs; @@ -74,29 +75,26 @@ async fn setup_init( ctx: &SetupContext, password: Option, ) -> Result<(Hostname, OnionAddressV3, X509), Error> { - let InitResult { secret_store, db } = init(&ctx.config).await?; - let mut secrets_handle = secret_store.acquire().await?; - let mut secrets_tx = secrets_handle.begin().await?; + let InitResult { db } = init(&ctx.config).await?; - let mut account = AccountInfo::load(secrets_tx.as_mut()).await?; - - if let Some(password) = password { - account.set_password(&password)?; - account.save(secrets_tx.as_mut()).await?; - db.mutate(|m| { + let account = db + .mutate(|m| { + let mut account = AccountInfo::load(m)?; + if let Some(password) = password { + account.set_password(&password)?; + } + account.save(m)?; m.as_public_mut() .as_server_info_mut() .as_password_hash_mut() - .ser(&account.password) + .ser(&account.password)?; + Ok(account) }) .await?; - } - - secrets_tx.commit().await?; Ok(( account.hostname, - account.key.tor_address(), + account.tor_key.public().get_onion_address(), account.root_ca_cert, )) } @@ -419,15 +417,13 @@ async fn fresh_setup( embassy_password: &str, ) -> Result<(Hostname, OnionAddressV3, X509), Error> { let account = AccountInfo::new(embassy_password, root_ca_start_time().await?)?; - let sqlite_pool = ctx.secret_store().await?; - account.save(&sqlite_pool).await?; - sqlite_pool.close().await; - let InitResult { secret_store, .. } = init(&ctx.config).await?; - secret_store.close().await; + let db = ctx.db().await?; + db.put(&ROOT, &Database::init(&account)?).await?; + init(&ctx.config).await?; Ok(( - account.hostname.clone(), - account.key.tor_address(), - account.root_ca_cert.clone(), + account.hostname, + account.tor_key.public().get_onion_address(), + account.root_ca_cert, )) } diff --git a/core/startos/src/ssh.rs b/core/startos/src/ssh.rs index d762b63a0..8965c7edd 100644 --- a/core/startos/src/ssh.rs +++ b/core/startos/src/ssh.rs @@ -1,28 +1,46 @@ +use std::collections::BTreeMap; use std::path::Path; -use chrono::Utc; use clap::builder::ValueParserFactory; use clap::Parser; use color_eyre::eyre::eyre; +use imbl_value::InternedString; use rpc_toolkit::{command, from_fn_async, AnyContext, Empty, HandlerExt, ParentHandler}; use serde::{Deserialize, Serialize}; -use sqlx::{Pool, Postgres}; use tracing::instrument; use crate::context::{CliContext, RpcContext}; +use crate::prelude::*; use crate::util::clap::FromStrParser; use crate::util::serde::{display_serializable, HandlerExtSerde, WithIoFormat}; -use crate::{Error, ErrorKind}; static SSH_AUTHORIZED_KEYS_FILE: &str = "/home/start9/.ssh/authorized_keys"; #[derive(Clone, Debug, Deserialize, Serialize)] -pub struct PubKey( +pub struct SshKeys(BTreeMap>); +impl SshKeys { + pub fn new() -> Self { + Self(BTreeMap::new()) + } +} +impl Map for SshKeys { + type Key = InternedString; + type Value = WithTimeData; + fn key_str(key: &Self::Key) -> Result, Error> { + Ok(key) + } + fn key_string(key: &Self::Key) -> Result { + Ok(key.clone()) + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct SshPubKey( #[serde(serialize_with = "crate::util::serde::serialize_display")] #[serde(deserialize_with = "crate::util::serde::deserialize_from_str")] openssh_keys::PublicKey, ); -impl ValueParserFactory for PubKey { +impl ValueParserFactory for SshPubKey { type Parser = FromStrParser; fn value_parser() -> Self::Parser { FromStrParser::new() @@ -33,7 +51,7 @@ impl ValueParserFactory for PubKey { #[serde(rename_all = "kebab-case")] pub struct SshKeyResponse { pub alg: String, - pub fingerprint: String, + pub fingerprint: InternedString, pub hostname: String, pub created_at: String, } @@ -47,10 +65,10 @@ impl std::fmt::Display for SshKeyResponse { } } -impl std::str::FromStr for PubKey { +impl std::str::FromStr for SshPubKey { type Err = Error; fn from_str(s: &str) -> Result { - s.parse().map(|pk| PubKey(pk)).map_err(|e| Error { + s.parse().map(|pk| SshPubKey(pk)).map_err(|e| Error { source: e.into(), kind: crate::ErrorKind::ParseSshKey, revision: None, @@ -88,49 +106,34 @@ pub fn ssh() -> ParentHandler { #[serde(rename_all = "kebab-case")] #[command(rename_all = "kebab-case")] pub struct AddParams { - key: PubKey, + key: SshPubKey, } #[instrument(skip_all)] pub async fn add(ctx: RpcContext, AddParams { key }: AddParams) -> Result { - let pool = &ctx.secret_store; - // check fingerprint for duplicates - let fp = key.0.fingerprint_md5(); - match sqlx::query!("SELECT * FROM ssh_keys WHERE fingerprint = $1", fp) - .fetch_optional(pool) - .await? - { - None => { - // if no duplicates, insert into DB - let raw_key = format!("{}", key.0); - let created_at = Utc::now().to_rfc3339(); - sqlx::query!( - "INSERT INTO ssh_keys (fingerprint, openssh_pubkey, created_at) VALUES ($1, $2, $3)", - fp, - raw_key, - created_at - ) - .execute(pool) - .await?; - // insert into live key file, for now we actually do a wholesale replacement of the keys file, for maximum - // consistency - sync_keys_from_db(pool, Path::new(SSH_AUTHORIZED_KEYS_FILE)).await?; + let mut key = WithTimeData::new(key); + let fingerprint = InternedString::intern(key.0.fingerprint_md5()); + ctx.db + .mutate(move |m| { + m.as_private_mut() + .as_ssh_pubkeys_mut() + .insert(&fingerprint, &key)?; + Ok(SshKeyResponse { alg: key.0.keytype().to_owned(), - fingerprint: fp, - hostname: key.0.comment.unwrap_or(String::new()).to_owned(), - created_at, + fingerprint, + hostname: key.0.comment.take().unwrap_or_default(), + created_at: key.created_at.to_rfc3339(), }) - } - Some(_) => Err(Error::new(eyre!("Duplicate ssh key"), ErrorKind::Duplicate)), - } + }) + .await } #[derive(Deserialize, Serialize, Parser)] #[serde(rename_all = "kebab-case")] #[command(rename_all = "kebab-case")] pub struct DeleteParams { - fingerprint: String, + fingerprint: InternedString, } #[instrument(skip_all)] @@ -138,25 +141,22 @@ pub async fn delete( ctx: RpcContext, DeleteParams { fingerprint }: DeleteParams, ) -> Result<(), Error> { - let pool = &ctx.secret_store; - // check if fingerprint is in DB - // if in DB, remove it from DB - let n = sqlx::query!("DELETE FROM ssh_keys WHERE fingerprint = $1", fingerprint) - .execute(pool) - .await? - .rows_affected(); - // if not in DB, Err404 - if n == 0 { - Err(Error { - source: color_eyre::eyre::eyre!("SSH Key Not Found"), - kind: crate::error::ErrorKind::NotFound, - revision: None, + let keys = ctx + .db + .mutate(|m| { + let keys_ref = m.as_private_mut().as_ssh_pubkeys_mut(); + if keys_ref.remove(&fingerprint)?.is_some() { + keys_ref.de() + } else { + Err(Error { + source: color_eyre::eyre::eyre!("SSH Key Not Found"), + kind: crate::error::ErrorKind::NotFound, + revision: None, + }) + } }) - } else { - // AND overlay key file - sync_keys_from_db(pool, Path::new(SSH_AUTHORIZED_KEYS_FILE)).await?; - Ok(()) - } + .await?; + sync_keys(&keys, SSH_AUTHORIZED_KEYS_FILE).await } fn display_all_ssh_keys(params: WithIoFormat, result: Vec) { @@ -186,43 +186,31 @@ fn display_all_ssh_keys(params: WithIoFormat, result: Vec } #[instrument(skip_all)] -pub async fn list(ctx: RpcContext, _: Empty) -> Result, Error> { - let pool = &ctx.secret_store; - // list keys in DB and return them - let entries = sqlx::query!("SELECT fingerprint, openssh_pubkey, created_at FROM ssh_keys") - .fetch_all(pool) - .await?; - Ok(entries +pub async fn list(ctx: RpcContext) -> Result, Error> { + ctx.db + .peek() + .await + .into_private() + .into_ssh_pubkeys() + .into_entries()? .into_iter() - .map(|r| { - let k = PubKey(r.openssh_pubkey.parse().unwrap()).0; - let alg = k.keytype().to_owned(); - let fingerprint = k.fingerprint_md5(); - let hostname = k.comment.unwrap_or("".to_owned()); - let created_at = r.created_at; - SshKeyResponse { - alg, + .map(|(fingerprint, key)| { + let mut key = key.de()?; + Ok(SshKeyResponse { + alg: key.0.keytype().to_owned(), fingerprint, - hostname, - created_at, - } + hostname: key.0.comment.take().unwrap_or_default(), + created_at: key.created_at.to_rfc3339(), + }) }) - .collect()) + .collect() } #[instrument(skip_all)] -pub async fn sync_keys_from_db>( - pool: &Pool, - dest: P, -) -> Result<(), Error> { +pub async fn sync_keys>(keys: &SshKeys, dest: P) -> Result<(), Error> { + use tokio::io::AsyncWriteExt; + let dest = dest.as_ref(); - let keys = sqlx::query!("SELECT openssh_pubkey FROM ssh_keys") - .fetch_all(pool) - .await?; - let contents: String = keys - .into_iter() - .map(|k| format!("{}\n", k.openssh_pubkey)) - .collect(); let ssh_dir = dest.parent().ok_or_else(|| { Error::new( eyre!("SSH Key File cannot be \"/\""), @@ -232,5 +220,10 @@ pub async fn sync_keys_from_db>( if tokio::fs::metadata(ssh_dir).await.is_err() { tokio::fs::create_dir_all(ssh_dir).await?; } - std::fs::write(dest, contents).map_err(|e| e.into()) + let mut f = tokio::fs::File::create(dest).await?; + for key in keys.0.values() { + f.write_all(key.0.to_key_format().as_bytes()).await?; + f.write_all(b"\n").await?; + } + Ok(()) } diff --git a/core/startos/src/status/mod.rs b/core/startos/src/status/mod.rs index ffc1a98bb..721b47511 100644 --- a/core/startos/src/status/mod.rs +++ b/core/startos/src/status/mod.rs @@ -27,6 +27,12 @@ pub struct DependencyConfigErrors(pub BTreeMap); impl Map for DependencyConfigErrors { type Key = PackageId; type Value = String; + fn key_str(key: &Self::Key) -> Result, Error> { + Ok(key) + } + fn key_string(key: &Self::Key) -> Result { + Ok(key.clone().into()) + } } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] diff --git a/core/startos/src/update/mod.rs b/core/startos/src/update/mod.rs index 31693aba0..9f2b58135 100644 --- a/core/startos/src/update/mod.rs +++ b/core/startos/src/update/mod.rs @@ -18,7 +18,7 @@ use crate::db::model::UpdateProgress; use crate::disk::mount::filesystem::bind::Bind; use crate::disk::mount::filesystem::ReadWrite; use crate::disk::mount::guard::MountGuard; -use crate::notifications::NotificationLevel; +use crate::notifications::{notify, NotificationLevel}; use crate::prelude::*; use crate::registry::marketplace::with_query_params; use crate::sound::{ @@ -66,7 +66,7 @@ pub enum UpdateResult { Updating, } -pub fn display_update_result(params: UpdateSystemParams, status: UpdateResult) { +pub fn display_update_result(_: UpdateSystemParams, status: UpdateResult) { match status { UpdateResult::Updating => { println!("Updating..."); @@ -131,24 +131,14 @@ async fn maybe_do_update(ctx: RpcContext, marketplace_url: Url) -> Result