Merge branch 'next/minor' of github.com:Start9Labs/start-os into feature/registry-metrics

This commit is contained in:
Aiden McClelland
2024-06-24 16:24:31 -06:00
274 changed files with 8258 additions and 7057 deletions

892
core/Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -24,7 +24,7 @@ pub use service_interface::ServiceInterfaceId;
pub use volume::VolumeId;
lazy_static::lazy_static! {
static ref ID_REGEX: Regex = Regex::new("^[a-z]+(-[a-z]+)*$").unwrap();
static ref ID_REGEX: Regex = Regex::new("^[a-z]+(-[a-z0-9]+)*$").unwrap();
pub static ref SYSTEM_ID: Id = Id(InternedString::intern("x_system"));
}

View File

@@ -59,6 +59,8 @@ async-stream = "0.3.5"
async-trait = "0.1.74"
axum = { version = "0.7.3", features = ["ws"] }
axum-server = "0.6.0"
barrage = "0.2.3"
backhand = "0.18.0"
base32 = "0.4.0"
base64 = "0.21.4"
base64ct = "1.6.0"
@@ -72,7 +74,6 @@ console = "0.15.7"
console-subscriber = { version = "0.2", optional = true }
cookie = "0.18.0"
cookie_store = "0.20.0"
current_platform = "0.2.0"
der = { version = "0.7.9", features = ["derive", "pem"] }
digest = "0.10.7"
divrem = "1.0.0"
@@ -96,13 +97,14 @@ hex = "0.4.3"
hmac = "0.12.1"
http = "1.0.0"
http-body-util = "0.1"
hyper-util = { version = "0.1.5", features = ["tokio", "service"] }
id-pool = { version = "0.2.2", default-features = false, features = [
"serde",
"u16",
] }
imbl = "2.0.2"
imbl-value = { git = "https://github.com/Start9Labs/imbl-value.git" }
include_dir = "0.7.3"
include_dir = { version = "0.7.3", features = ["metadata"] }
indexmap = { version = "2.0.2", features = ["serde"] }
indicatif = { version = "0.17.7", features = ["tokio"] }
integer-encoding = { version = "4.0.0", features = ["tokio_async"] }
@@ -154,10 +156,11 @@ serde_json = "1.0"
serde_toml = { package = "toml", version = "0.8.2" }
serde_urlencoded = "0.7"
serde_with = { version = "3.4.0", features = ["macros", "json"] }
serde_yaml = "0.9.25"
serde_yaml = { package = "serde_yml", version = "0.0.10" }
sha2 = "0.10.2"
shell-words = "1"
simple-logging = "2.0.2"
socket2 = "0.5.7"
sqlx = { version = "0.7.2", features = [
"chrono",
"runtime-tokio-rustls",
@@ -178,6 +181,7 @@ tokio-util = { version = "0.7.9", features = ["io"] }
torut = { git = "https://github.com/Start9Labs/torut.git", branch = "update/dependencies", features = [
"serialize",
] }
tower-service = "0.3.2"
tracing = "0.1.39"
tracing-error = "0.2.0"
tracing-futures = "0.2.5"

View File

@@ -8,6 +8,7 @@ use ts_rs::TS;
use crate::config::Config;
use crate::context::RpcContext;
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::util::serde::{display_serializable, StdinDeserializable, WithIoFormat};
#[derive(Debug, Serialize, Deserialize)]
@@ -77,6 +78,7 @@ pub async fn action(
.as_ref()
.or_not_found(lazy_format!("Manager for {}", package_id))?
.action(
Guid::new(),
action_id,
input.map(|c| to_value(&c)).transpose()?.unwrap_or_default(),
)

View File

@@ -178,6 +178,7 @@ pub fn check_password_against_db(db: &DatabaseModel, password: &str) -> Result<(
#[derive(Deserialize, Serialize, Parser, TS)]
#[serde(rename_all = "camelCase")]
#[command(rename_all = "kebab-case")]
#[ts(export)]
pub struct LoginParams {
password: Option<PasswordType>,
#[ts(skip)]

View File

@@ -4,25 +4,25 @@ use std::sync::Arc;
use clap::Parser;
use futures::{stream, StreamExt};
use models::PackageId;
use openssl::x509::X509;
use patch_db::json_ptr::ROOT;
use serde::{Deserialize, Serialize};
use torut::onion::OnionAddressV3;
use tokio::sync::Mutex;
use tracing::instrument;
use ts_rs::TS;
use super::target::BackupTargetId;
use crate::backup::os::OsBackup;
use crate::context::setup::SetupResult;
use crate::context::{RpcContext, SetupContext};
use crate::db::model::Database;
use crate::disk::mount::backup::BackupMountGuard;
use crate::disk::mount::filesystem::ReadWrite;
use crate::disk::mount::guard::{GenericMountGuard, TmpMountGuard};
use crate::hostname::Hostname;
use crate::init::init;
use crate::init::{init, InitResult};
use crate::prelude::*;
use crate::s9pk::S9pk;
use crate::service::service_map::DownloadInstallFuture;
use crate::setup::SetupExecuteProgress;
use crate::util::serde::IoFormat;
#[derive(Deserialize, Serialize, Parser, TS)]
@@ -67,14 +67,21 @@ pub async fn restore_packages_rpc(
Ok(())
}
#[instrument(skip(ctx))]
#[instrument(skip_all)]
pub async fn recover_full_embassy(
ctx: SetupContext,
ctx: &SetupContext,
disk_guid: Arc<String>,
start_os_password: String,
recovery_source: TmpMountGuard,
recovery_password: Option<String>,
) -> Result<(Arc<String>, Hostname, OnionAddressV3, X509), Error> {
SetupExecuteProgress {
init_phases,
restore_phase,
rpc_ctx_phases,
}: SetupExecuteProgress,
) -> Result<(SetupResult, RpcContext), Error> {
let mut restore_phase = restore_phase.or_not_found("restore progress")?;
let backup_guard = BackupMountGuard::mount(
recovery_source,
recovery_password.as_deref().unwrap_or_default(),
@@ -99,10 +106,17 @@ pub async fn recover_full_embassy(
db.put(&ROOT, &Database::init(&os_backup.account)?).await?;
drop(db);
init(&ctx.config).await?;
let InitResult { net_ctrl } = init(&ctx.config, init_phases).await?;
let rpc_ctx = RpcContext::init(&ctx.config, disk_guid.clone()).await?;
let rpc_ctx = RpcContext::init(
&ctx.config,
disk_guid.clone(),
Some(net_ctrl),
rpc_ctx_phases,
)
.await?;
restore_phase.start();
let ids: Vec<_> = backup_guard
.metadata
.package_backups
@@ -110,26 +124,26 @@ pub async fn recover_full_embassy(
.cloned()
.collect();
let tasks = restore_packages(&rpc_ctx, backup_guard, ids).await?;
restore_phase.set_total(tasks.len() as u64);
let restore_phase = Arc::new(Mutex::new(restore_phase));
stream::iter(tasks)
.for_each_concurrent(5, |(id, res)| async move {
match async { res.await?.await }.await {
Ok(_) => (),
Err(err) => {
tracing::error!("Error restoring package {}: {}", id, err);
tracing::debug!("{:?}", err);
.for_each_concurrent(5, |(id, res)| {
let restore_phase = restore_phase.clone();
async move {
match async { res.await?.await }.await {
Ok(_) => (),
Err(err) => {
tracing::error!("Error restoring package {}: {}", id, err);
tracing::debug!("{:?}", err);
}
}
*restore_phase.lock().await += 1;
}
})
.await;
restore_phase.lock().await.complete();
rpc_ctx.shutdown().await?;
Ok((
disk_guid,
os_backup.account.hostname,
os_backup.account.tor_key.public().get_onion_address(),
os_backup.account.root_ca_cert,
))
Ok(((&os_backup.account).try_into()?, rpc_ctx))
}
#[instrument(skip(ctx, backup_guard))]
@@ -149,7 +163,6 @@ async fn restore_packages(
S9pk::open(
backup_dir.path().join(&id).with_extension("s9pk"),
Some(&id),
true,
)
.await?,
Some(backup_dir),

View File

@@ -14,7 +14,8 @@ use crate::util::logger::EmbassyLogger;
async fn inner_main(config: &RegistryConfig) -> Result<(), Error> {
let server = async {
let ctx = RegistryContext::init(config).await?;
let server = WebServer::registry(ctx.listen, ctx.clone());
let mut server = WebServer::new(ctx.listen);
server.serve_registry(ctx.clone());
let mut shutdown_recv = ctx.shutdown.subscribe();

View File

@@ -1,47 +1,56 @@
use std::net::{Ipv6Addr, SocketAddr};
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use helpers::NonDetachingJoinHandle;
use tokio::process::Command;
use tracing::instrument;
use crate::context::config::ServerConfig;
use crate::context::{DiagnosticContext, InstallContext, SetupContext};
use crate::disk::fsck::{RepairStrategy, RequiresReboot};
use crate::context::rpc::InitRpcContextPhases;
use crate::context::{DiagnosticContext, InitContext, InstallContext, RpcContext, SetupContext};
use crate::disk::fsck::RepairStrategy;
use crate::disk::main::DEFAULT_PASSWORD;
use crate::disk::REPAIR_DISK_PATH;
use crate::firmware::update_firmware;
use crate::init::STANDBY_MODE_PATH;
use crate::firmware::{check_for_firmware_update, update_firmware};
use crate::init::{InitPhases, InitResult, STANDBY_MODE_PATH};
use crate::net::web_server::WebServer;
use crate::prelude::*;
use crate::progress::FullProgressTracker;
use crate::shutdown::Shutdown;
use crate::sound::{BEP, CHIME};
use crate::util::Invoke;
use crate::{Error, ErrorKind, ResultExt, PLATFORM};
use crate::PLATFORM;
#[instrument(skip_all)]
async fn setup_or_init(config: &ServerConfig) -> Result<Option<Shutdown>, Error> {
let song = NonDetachingJoinHandle::from(tokio::spawn(async {
loop {
BEP.play().await.unwrap();
BEP.play().await.unwrap();
tokio::time::sleep(Duration::from_secs(30)).await;
}
}));
async fn setup_or_init(
server: &mut WebServer,
config: &ServerConfig,
) -> Result<Result<(RpcContext, FullProgressTracker), Shutdown>, Error> {
if let Some(firmware) = check_for_firmware_update()
.await
.map_err(|e| {
tracing::warn!("Error checking for firmware update: {e}");
tracing::debug!("{e:?}");
})
.ok()
.and_then(|a| a)
{
let init_ctx = InitContext::init(config).await?;
let handle = &init_ctx.progress;
let mut update_phase = handle.add_phase("Updating Firmware".into(), Some(10));
let mut reboot_phase = handle.add_phase("Rebooting".into(), Some(1));
match update_firmware().await {
Ok(RequiresReboot(true)) => {
return Ok(Some(Shutdown {
export_args: None,
restart: true,
}))
}
Err(e) => {
server.serve_init(init_ctx);
update_phase.start();
if let Err(e) = update_firmware(firmware).await {
tracing::warn!("Error performing firmware update: {e}");
tracing::debug!("{e:?}");
} else {
update_phase.complete();
reboot_phase.start();
return Ok(Err(Shutdown {
export_args: None,
restart: true,
}));
}
_ => (),
}
Command::new("ln")
@@ -84,14 +93,7 @@ async fn setup_or_init(config: &ServerConfig) -> Result<Option<Shutdown>, Error>
let ctx = InstallContext::init().await?;
let server = WebServer::install(
SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 80),
ctx.clone(),
)?;
drop(song);
tokio::time::sleep(Duration::from_secs(1)).await; // let the record state that I hate this
CHIME.play().await?;
server.serve_install(ctx.clone());
ctx.shutdown
.subscribe()
@@ -99,33 +101,23 @@ async fn setup_or_init(config: &ServerConfig) -> Result<Option<Shutdown>, Error>
.await
.expect("context dropped");
server.shutdown().await;
return Ok(Err(Shutdown {
export_args: None,
restart: true,
}));
}
Command::new("reboot")
.invoke(crate::ErrorKind::Unknown)
.await?;
} else if tokio::fs::metadata("/media/startos/config/disk.guid")
if tokio::fs::metadata("/media/startos/config/disk.guid")
.await
.is_err()
{
let ctx = SetupContext::init(config)?;
let server = WebServer::setup(
SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 80),
ctx.clone(),
)?;
drop(song);
tokio::time::sleep(Duration::from_secs(1)).await; // let the record state that I hate this
CHIME.play().await?;
server.serve_setup(ctx.clone());
let mut shutdown = ctx.shutdown.subscribe();
shutdown.recv().await.expect("context dropped");
server.shutdown().await;
drop(shutdown);
tokio::task::yield_now().await;
if let Err(e) = Command::new("killall")
.arg("firefox-esr")
@@ -135,19 +127,40 @@ async fn setup_or_init(config: &ServerConfig) -> Result<Option<Shutdown>, Error>
tracing::error!("Failed to kill kiosk: {}", e);
tracing::debug!("{:?}", e);
}
Ok(Ok(match ctx.result.get() {
Some(Ok((_, rpc_ctx))) => (rpc_ctx.clone(), ctx.progress.clone()),
Some(Err(e)) => return Err(e.clone_output()),
None => {
return Err(Error::new(
eyre!("Setup mode exited before setup completed"),
ErrorKind::Unknown,
))
}
}))
} else {
let init_ctx = InitContext::init(config).await?;
let handle = init_ctx.progress.clone();
let mut disk_phase = handle.add_phase("Opening data drive".into(), Some(10));
let init_phases = InitPhases::new(&handle);
let rpc_ctx_phases = InitRpcContextPhases::new(&handle);
server.serve_init(init_ctx);
disk_phase.start();
let guid_string = tokio::fs::read_to_string("/media/startos/config/disk.guid") // unique identifier for volume group - keeps track of the disk that goes with your embassy
.await?;
let guid = guid_string.trim();
let disk_guid = Arc::new(String::from(guid_string.trim()));
let requires_reboot = crate::disk::main::import(
guid,
&**disk_guid,
config.datadir(),
if tokio::fs::metadata(REPAIR_DISK_PATH).await.is_ok() {
RepairStrategy::Aggressive
} else {
RepairStrategy::Preen
},
if guid.ends_with("_UNENC") {
if disk_guid.ends_with("_UNENC") {
None
} else {
Some(DEFAULT_PASSWORD)
@@ -159,40 +172,31 @@ async fn setup_or_init(config: &ServerConfig) -> Result<Option<Shutdown>, Error>
.await
.with_ctx(|_| (crate::ErrorKind::Filesystem, REPAIR_DISK_PATH))?;
}
if requires_reboot.0 {
crate::disk::main::export(guid, config.datadir()).await?;
Command::new("reboot")
.invoke(crate::ErrorKind::Unknown)
.await?;
}
disk_phase.complete();
tracing::info!("Loaded Disk");
crate::init::init(config).await?;
drop(song);
}
Ok(None)
}
async fn run_script_if_exists<P: AsRef<Path>>(path: P) {
let script = path.as_ref();
if script.exists() {
match Command::new("/bin/bash").arg(script).spawn() {
Ok(mut c) => {
if let Err(e) = c.wait().await {
tracing::error!("Error Running {}: {}", script.display(), e);
tracing::debug!("{:?}", e);
}
}
Err(e) => {
tracing::error!("Error Running {}: {}", script.display(), e);
tracing::debug!("{:?}", e);
}
if requires_reboot.0 {
let mut reboot_phase = handle.add_phase("Rebooting".into(), Some(1));
reboot_phase.start();
return Ok(Err(Shutdown {
export_args: Some((disk_guid, config.datadir().to_owned())),
restart: true,
}));
}
let InitResult { net_ctrl } = crate::init::init(config, init_phases).await?;
let rpc_ctx = RpcContext::init(config, disk_guid, Some(net_ctrl), rpc_ctx_phases).await?;
Ok(Ok((rpc_ctx, handle)))
}
}
#[instrument(skip_all)]
async fn inner_main(config: &ServerConfig) -> Result<Option<Shutdown>, Error> {
pub async fn main(
server: &mut WebServer,
config: &ServerConfig,
) -> Result<Result<(RpcContext, FullProgressTracker), Shutdown>, Error> {
if &*PLATFORM == "raspberrypi" && tokio::fs::metadata(STANDBY_MODE_PATH).await.is_ok() {
tokio::fs::remove_file(STANDBY_MODE_PATH).await?;
Command::new("sync").invoke(ErrorKind::Filesystem).await?;
@@ -200,16 +204,11 @@ async fn inner_main(config: &ServerConfig) -> Result<Option<Shutdown>, Error> {
futures::future::pending::<()>().await;
}
crate::sound::BEP.play().await?;
run_script_if_exists("/media/startos/config/preinit.sh").await;
let res = match setup_or_init(config).await {
let res = match setup_or_init(server, config).await {
Err(e) => {
async move {
tracing::error!("{}", e.source);
tracing::debug!("{}", e.source);
crate::sound::BEETHOVEN.play().await?;
tracing::error!("{e}");
tracing::debug!("{e:?}");
let ctx = DiagnosticContext::init(
config,
@@ -229,44 +228,16 @@ async fn inner_main(config: &ServerConfig) -> Result<Option<Shutdown>, Error> {
e,
)?;
let server = WebServer::diagnostic(
SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 80),
ctx.clone(),
)?;
server.serve_diagnostic(ctx.clone());
let shutdown = ctx.shutdown.subscribe().recv().await.unwrap();
server.shutdown().await;
Ok(shutdown)
Ok(Err(shutdown))
}
.await
}
Ok(s) => Ok(s),
};
run_script_if_exists("/media/startos/config/postinit.sh").await;
res
}
pub fn main(config: &ServerConfig) {
let res = {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("failed to initialize runtime");
rt.block_on(inner_main(config))
};
match res {
Ok(Some(shutdown)) => shutdown.execute(),
Ok(None) => (),
Err(e) => {
eprintln!("{}", e.source);
tracing::debug!("{:?}", e.source);
drop(e.source);
std::process::exit(e.kind as i32)
}
}
}

View File

@@ -1,6 +1,5 @@
use std::ffi::OsString;
use std::net::{Ipv6Addr, SocketAddr};
use std::path::Path;
use std::sync::Arc;
use clap::Parser;
@@ -10,7 +9,8 @@ use tokio::signal::unix::signal;
use tracing::instrument;
use crate::context::config::ServerConfig;
use crate::context::{DiagnosticContext, RpcContext};
use crate::context::rpc::InitRpcContextPhases;
use crate::context::{DiagnosticContext, InitContext, RpcContext};
use crate::net::web_server::WebServer;
use crate::shutdown::Shutdown;
use crate::system::launch_metrics_task;
@@ -18,9 +18,31 @@ use crate::util::logger::EmbassyLogger;
use crate::{Error, ErrorKind, ResultExt};
#[instrument(skip_all)]
async fn inner_main(config: &ServerConfig) -> Result<Option<Shutdown>, Error> {
let (rpc_ctx, server, shutdown) = async {
let rpc_ctx = RpcContext::init(
async fn inner_main(
server: &mut WebServer,
config: &ServerConfig,
) -> Result<Option<Shutdown>, Error> {
let rpc_ctx = if !tokio::fs::metadata("/run/startos/initialized")
.await
.is_ok()
{
let (ctx, handle) = match super::start_init::main(server, &config).await? {
Err(s) => return Ok(Some(s)),
Ok(ctx) => ctx,
};
tokio::fs::write("/run/startos/initialized", "").await?;
server.serve_main(ctx.clone());
handle.complete();
ctx
} else {
let init_ctx = InitContext::init(config).await?;
let handle = init_ctx.progress.clone();
let rpc_ctx_phases = InitRpcContextPhases::new(&handle);
server.serve_init(init_ctx);
let ctx = RpcContext::init(
config,
Arc::new(
tokio::fs::read_to_string("/media/startos/config/disk.guid") // unique identifier for volume group - keeps track of the disk that goes with your embassy
@@ -28,13 +50,19 @@ async fn inner_main(config: &ServerConfig) -> Result<Option<Shutdown>, Error> {
.trim()
.to_owned(),
),
None,
rpc_ctx_phases,
)
.await?;
server.serve_main(ctx.clone());
handle.complete();
ctx
};
let (rpc_ctx, shutdown) = async {
crate::hostname::sync_hostname(&rpc_ctx.account.read().await.hostname).await?;
let server = WebServer::main(
SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 80),
rpc_ctx.clone(),
)?;
let mut shutdown_recv = rpc_ctx.shutdown.subscribe();
@@ -74,8 +102,6 @@ async fn inner_main(config: &ServerConfig) -> Result<Option<Shutdown>, Error> {
.await
});
crate::sound::CHIME.play().await?;
metrics_task
.map_err(|e| {
Error::new(
@@ -93,10 +119,9 @@ async fn inner_main(config: &ServerConfig) -> Result<Option<Shutdown>, Error> {
sig_handler.abort();
Ok::<_, Error>((rpc_ctx, server, shutdown))
Ok::<_, Error>((rpc_ctx, shutdown))
}
.await?;
server.shutdown().await;
rpc_ctx.shutdown().await?;
tracing::info!("RPC Context is dropped");
@@ -109,24 +134,22 @@ pub fn main(args: impl IntoIterator<Item = OsString>) {
let config = ServerConfig::parse_from(args).load().unwrap();
if !Path::new("/run/embassy/initialized").exists() {
super::start_init::main(&config);
std::fs::write("/run/embassy/initialized", "").unwrap();
}
let res = {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("failed to initialize runtime");
rt.block_on(async {
match inner_main(&config).await {
Ok(a) => Ok(a),
let mut server = WebServer::new(SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 80));
match inner_main(&mut server, &config).await {
Ok(a) => {
server.shutdown().await;
Ok(a)
}
Err(e) => {
async {
tracing::error!("{}", e.source);
tracing::debug!("{:?}", e.source);
crate::sound::BEETHOVEN.play().await?;
tracing::error!("{e}");
tracing::debug!("{e:?}");
let ctx = DiagnosticContext::init(
&config,
if tokio::fs::metadata("/media/startos/config/disk.guid")
@@ -145,10 +168,7 @@ pub fn main(args: impl IntoIterator<Item = OsString>) {
e,
)?;
let server = WebServer::diagnostic(
SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 80),
ctx.clone(),
)?;
server.serve_diagnostic(ctx.clone());
let mut shutdown = ctx.shutdown.subscribe();
@@ -157,7 +177,7 @@ pub fn main(args: impl IntoIterator<Item = OsString>) {
server.shutdown().await;
Ok::<_, Error>(shutdown)
Ok::<_, Error>(Some(shutdown))
}
.await
}

View File

@@ -16,6 +16,7 @@ use ts_rs::TS;
use crate::context::{CliContext, RpcContext};
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::util::serde::{HandlerExtSerde, StdinDeserializable};
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
@@ -156,7 +157,7 @@ pub async fn get(ctx: RpcContext, _: Empty, id: PackageId) -> Result<ConfigRes,
.await
.as_ref()
.or_not_found(lazy_format!("Manager for {id}"))?
.get_config()
.get_config(Guid::new())
.await
}
@@ -218,7 +219,7 @@ pub async fn set_impl(
ErrorKind::Unknown,
)
})?
.configure(configure_context)
.configure(Guid::new(), configure_context)
.await?;
Ok(())
}

View File

@@ -18,7 +18,7 @@ use tracing::instrument;
use super::setup::CURRENT_SECRET;
use crate::context::config::{local_config_path, ClientConfig};
use crate::context::{DiagnosticContext, InstallContext, RpcContext, SetupContext};
use crate::context::{DiagnosticContext, InitContext, InstallContext, RpcContext, SetupContext};
use crate::middleware::auth::LOCAL_AUTH_COOKIE_PATH;
use crate::prelude::*;
use crate::rpc_continuations::Guid;
@@ -271,6 +271,11 @@ impl CallRemote<DiagnosticContext> for CliContext {
call_remote_http(&self.client, self.rpc_url.clone(), method, params).await
}
}
impl CallRemote<InitContext> for CliContext {
async fn call_remote(&self, method: &str, params: Value, _: Empty) -> Result<Value, RpcError> {
call_remote_http(&self.client, self.rpc_url.clone(), method, params).await
}
}
impl CallRemote<SetupContext> for CliContext {
async fn call_remote(&self, method: &str, params: Value, _: Empty) -> Result<Value, RpcError> {
call_remote_http(&self.client, self.rpc_url.clone(), method, params).await

View File

@@ -93,26 +93,28 @@ impl ClientConfig {
#[serde(rename_all = "kebab-case")]
#[command(rename_all = "kebab-case")]
pub struct ServerConfig {
#[arg(short = 'c', long = "config")]
#[arg(short, long)]
pub config: Option<PathBuf>,
#[arg(long = "ethernet-interface")]
#[arg(long)]
pub ethernet_interface: Option<String>,
#[arg(skip)]
pub os_partitions: Option<OsPartitionInfo>,
#[arg(long = "bind-rpc")]
#[arg(long)]
pub bind_rpc: Option<SocketAddr>,
#[arg(long = "tor-control")]
#[arg(long)]
pub tor_control: Option<SocketAddr>,
#[arg(long = "tor-socks")]
#[arg(long)]
pub tor_socks: Option<SocketAddr>,
#[arg(long = "dns-bind")]
#[arg(long)]
pub dns_bind: Option<Vec<SocketAddr>>,
#[arg(long = "revision-cache-size")]
#[arg(long)]
pub revision_cache_size: Option<usize>,
#[arg(short = 'd', long = "datadir")]
#[arg(short, long)]
pub datadir: Option<PathBuf>,
#[arg(long = "disable-encryption")]
#[arg(long)]
pub disable_encryption: Option<bool>,
#[arg(long)]
pub multi_arch_s9pks: Option<bool>,
}
impl ContextConfig for ServerConfig {
fn next(&mut self) -> Option<PathBuf> {
@@ -131,6 +133,7 @@ impl ContextConfig for ServerConfig {
.or(other.revision_cache_size);
self.datadir = self.datadir.take().or(other.datadir);
self.disable_encryption = self.disable_encryption.take().or(other.disable_encryption);
self.multi_arch_s9pks = self.multi_arch_s9pks.take().or(other.multi_arch_s9pks);
}
}

View File

@@ -14,7 +14,7 @@ use crate::Error;
pub struct DiagnosticContextSeed {
pub datadir: PathBuf,
pub shutdown: Sender<Option<Shutdown>>,
pub shutdown: Sender<Shutdown>,
pub error: Arc<RpcError>,
pub disk_guid: Option<Arc<String>>,
pub rpc_continuations: RpcContinuations,

View File

@@ -0,0 +1,47 @@
use std::ops::Deref;
use std::sync::Arc;
use rpc_toolkit::Context;
use tokio::sync::broadcast::Sender;
use tracing::instrument;
use crate::context::config::ServerConfig;
use crate::progress::FullProgressTracker;
use crate::rpc_continuations::RpcContinuations;
use crate::Error;
pub struct InitContextSeed {
pub config: ServerConfig,
pub progress: FullProgressTracker,
pub shutdown: Sender<()>,
pub rpc_continuations: RpcContinuations,
}
#[derive(Clone)]
pub struct InitContext(Arc<InitContextSeed>);
impl InitContext {
#[instrument(skip_all)]
pub async fn init(cfg: &ServerConfig) -> Result<Self, Error> {
let (shutdown, _) = tokio::sync::broadcast::channel(1);
Ok(Self(Arc::new(InitContextSeed {
config: cfg.clone(),
progress: FullProgressTracker::new(),
shutdown,
rpc_continuations: RpcContinuations::new(),
})))
}
}
impl AsRef<RpcContinuations> for InitContext {
fn as_ref(&self) -> &RpcContinuations {
&self.rpc_continuations
}
}
impl Context for InitContext {}
impl Deref for InitContext {
type Target = InitContextSeed;
fn deref(&self) -> &Self::Target {
&*self.0
}
}

View File

@@ -6,11 +6,13 @@ use tokio::sync::broadcast::Sender;
use tracing::instrument;
use crate::net::utils::find_eth_iface;
use crate::rpc_continuations::RpcContinuations;
use crate::Error;
pub struct InstallContextSeed {
pub ethernet_interface: String,
pub shutdown: Sender<()>,
pub rpc_continuations: RpcContinuations,
}
#[derive(Clone)]
@@ -22,10 +24,17 @@ impl InstallContext {
Ok(Self(Arc::new(InstallContextSeed {
ethernet_interface: find_eth_iface().await?,
shutdown,
rpc_continuations: RpcContinuations::new(),
})))
}
}
impl AsRef<RpcContinuations> for InstallContext {
fn as_ref(&self) -> &RpcContinuations {
&self.rpc_continuations
}
}
impl Context for InstallContext {}
impl Deref for InstallContext {
type Target = InstallContextSeed;

View File

@@ -1,12 +1,14 @@
pub mod cli;
pub mod config;
pub mod diagnostic;
pub mod init;
pub mod install;
pub mod rpc;
pub mod setup;
pub use cli::CliContext;
pub use diagnostic::DiagnosticContext;
pub use init::InitContext;
pub use install::InstallContext;
pub use rpc::RpcContext;
pub use setup::SetupContext;

View File

@@ -6,11 +6,12 @@ use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use imbl_value::InternedString;
use josekit::jwk::Jwk;
use reqwest::{Client, Proxy};
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{CallRemote, Context, Empty};
use tokio::sync::{broadcast, oneshot, Mutex, RwLock};
use tokio::sync::{broadcast, Mutex, RwLock};
use tokio::time::Instant;
use tracing::instrument;
use url::Url;
@@ -23,12 +24,12 @@ use crate::dependencies::compute_dependency_config_errs;
use crate::disk::OsPartitionInfo;
use crate::init::check_time_is_synchronized;
use crate::lxc::{ContainerId, LxcContainer, LxcManager};
use crate::middleware::auth::HashSessionToken;
use crate::net::net_controller::NetController;
use crate::net::net_controller::{NetController, PreInitNetController};
use crate::net::utils::{find_eth_iface, find_wifi_iface};
use crate::net::wifi::WpaCli;
use crate::prelude::*;
use crate::rpc_continuations::RpcContinuations;
use crate::progress::{FullProgressTracker, PhaseProgressTrackerHandle};
use crate::rpc_continuations::{OpenAuthedContinuations, RpcContinuations};
use crate::service::ServiceMap;
use crate::shutdown::Shutdown;
use crate::system::get_mem_info;
@@ -44,12 +45,13 @@ pub struct RpcContextSeed {
pub db: TypedPatchDb<Database>,
pub account: RwLock<AccountInfo>,
pub net_controller: Arc<NetController>,
pub s9pk_arch: Option<&'static str>,
pub services: ServiceMap,
pub metrics_cache: RwLock<Option<crate::system::Metrics>>,
pub shutdown: broadcast::Sender<Option<Shutdown>>,
pub tor_socks: SocketAddr,
pub lxc_manager: Arc<LxcManager>,
pub open_authed_websockets: Mutex<BTreeMap<HashSessionToken, Vec<oneshot::Sender<()>>>>,
pub open_authed_continuations: OpenAuthedContinuations<InternedString>,
pub rpc_continuations: RpcContinuations,
pub wifi_manager: Option<Arc<RwLock<WpaCli>>>,
pub current_secret: Arc<Jwk>,
@@ -69,45 +71,103 @@ pub struct Hardware {
pub ram: u64,
}
pub struct InitRpcContextPhases {
load_db: PhaseProgressTrackerHandle,
init_net_ctrl: PhaseProgressTrackerHandle,
read_device_info: PhaseProgressTrackerHandle,
cleanup_init: CleanupInitPhases,
}
impl InitRpcContextPhases {
pub fn new(handle: &FullProgressTracker) -> Self {
Self {
load_db: handle.add_phase("Loading database".into(), Some(5)),
init_net_ctrl: handle.add_phase("Initializing network".into(), Some(1)),
read_device_info: handle.add_phase("Reading device information".into(), Some(1)),
cleanup_init: CleanupInitPhases::new(handle),
}
}
}
pub struct CleanupInitPhases {
init_services: PhaseProgressTrackerHandle,
check_dependencies: PhaseProgressTrackerHandle,
}
impl CleanupInitPhases {
pub fn new(handle: &FullProgressTracker) -> Self {
Self {
init_services: handle.add_phase("Initializing services".into(), Some(10)),
check_dependencies: handle.add_phase("Checking dependencies".into(), Some(1)),
}
}
}
#[derive(Clone)]
pub struct RpcContext(Arc<RpcContextSeed>);
impl RpcContext {
#[instrument(skip_all)]
pub async fn init(config: &ServerConfig, disk_guid: Arc<String>) -> Result<Self, Error> {
tracing::info!("Loaded Config");
pub async fn init(
config: &ServerConfig,
disk_guid: Arc<String>,
net_ctrl: Option<PreInitNetController>,
InitRpcContextPhases {
mut load_db,
mut init_net_ctrl,
mut read_device_info,
cleanup_init,
}: InitRpcContextPhases,
) -> Result<Self, Error> {
let tor_proxy = config.tor_socks.unwrap_or(SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::new(127, 0, 0, 1),
9050,
)));
let (shutdown, _) = tokio::sync::broadcast::channel(1);
let db = TypedPatchDb::<Database>::load(config.db().await?).await?;
load_db.start();
let db = if let Some(net_ctrl) = &net_ctrl {
net_ctrl.db.clone()
} else {
TypedPatchDb::<Database>::load(config.db().await?).await?
};
let peek = db.peek().await;
let account = AccountInfo::load(&peek)?;
load_db.complete();
tracing::info!("Opened PatchDB");
init_net_ctrl.start();
let net_controller = Arc::new(
NetController::init(
db.clone(),
config
.tor_control
.unwrap_or(SocketAddr::from(([127, 0, 0, 1], 9051))),
tor_proxy,
if let Some(net_ctrl) = net_ctrl {
net_ctrl
} else {
PreInitNetController::init(
db.clone(),
config
.tor_control
.unwrap_or(SocketAddr::from(([127, 0, 0, 1], 9051))),
tor_proxy,
&account.hostname,
account.tor_key.clone(),
)
.await?
},
config
.dns_bind
.as_deref()
.unwrap_or(&[SocketAddr::from(([127, 0, 0, 1], 53))]),
&account.hostname,
account.tor_key.clone(),
)
.await?,
);
init_net_ctrl.complete();
tracing::info!("Initialized Net Controller");
let services = ServiceMap::default();
let metrics_cache = RwLock::<Option<crate::system::Metrics>>::new(None);
tracing::info!("Initialized Notification Manager");
let tor_proxy_url = format!("socks5h://{tor_proxy}");
read_device_info.start();
let devices = lshw().await?;
let ram = get_mem_info().await?.total.0 as u64 * 1024 * 1024;
read_device_info.complete();
if !db
.peek()
@@ -154,12 +214,17 @@ impl RpcContext {
db,
account: RwLock::new(account),
net_controller,
s9pk_arch: if config.multi_arch_s9pks.unwrap_or(false) {
None
} else {
Some(crate::ARCH)
},
services,
metrics_cache,
shutdown,
tor_socks: tor_proxy,
lxc_manager: Arc::new(LxcManager::new()),
open_authed_websockets: Mutex::new(BTreeMap::new()),
open_authed_continuations: OpenAuthedContinuations::new(),
rpc_continuations: RpcContinuations::new(),
wifi_manager: wifi_interface
.clone()
@@ -193,7 +258,7 @@ impl RpcContext {
});
let res = Self(seed.clone());
res.cleanup_and_initialize().await?;
res.cleanup_and_initialize(cleanup_init).await?;
tracing::info!("Cleaned up transient states");
Ok(res)
}
@@ -207,11 +272,18 @@ impl RpcContext {
Ok(())
}
#[instrument(skip(self))]
pub async fn cleanup_and_initialize(&self) -> Result<(), Error> {
self.services.init(&self).await?;
#[instrument(skip_all)]
pub async fn cleanup_and_initialize(
&self,
CleanupInitPhases {
init_services,
mut check_dependencies,
}: CleanupInitPhases,
) -> Result<(), Error> {
self.services.init(&self, init_services).await?;
tracing::info!("Initialized Package Managers");
check_dependencies.start();
let mut updated_current_dependents = BTreeMap::new();
let peek = self.db.peek().await;
for (package_id, package) in peek.as_public().as_package_data().as_entries()?.into_iter() {
@@ -235,6 +307,7 @@ impl RpcContext {
Ok(())
})
.await?;
check_dependencies.complete();
Ok(())
}
@@ -271,6 +344,11 @@ impl AsRef<RpcContinuations> for RpcContext {
&self.rpc_continuations
}
}
impl AsRef<OpenAuthedContinuations<InternedString>> for RpcContext {
fn as_ref(&self) -> &OpenAuthedContinuations<InternedString> {
&self.open_authed_continuations
}
}
impl Context for RpcContext {}
impl Deref for RpcContext {
type Target = RpcContextSeed;

View File

@@ -1,23 +1,31 @@
use std::ops::Deref;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use futures::{Future, StreamExt};
use helpers::NonDetachingJoinHandle;
use josekit::jwk::Jwk;
use patch_db::PatchDb;
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::Context;
use serde::{Deserialize, Serialize};
use sqlx::postgres::PgConnectOptions;
use sqlx::PgPool;
use tokio::sync::broadcast::Sender;
use tokio::sync::RwLock;
use tokio::sync::OnceCell;
use tracing::instrument;
use ts_rs::TS;
use crate::account::AccountInfo;
use crate::context::config::ServerConfig;
use crate::context::RpcContext;
use crate::disk::OsPartitionInfo;
use crate::init::init_postgres;
use crate::prelude::*;
use crate::setup::SetupStatus;
use crate::progress::FullProgressTracker;
use crate::rpc_continuations::{Guid, RpcContinuation, RpcContinuations};
use crate::setup::SetupProgress;
use crate::util::net::WebSocketExt;
lazy_static::lazy_static! {
pub static ref CURRENT_SECRET: Jwk = Jwk::generate_ec_key(josekit::jwk::alg::ec::EcCurve::P256).unwrap_or_else(|e| {
@@ -27,30 +35,35 @@ lazy_static::lazy_static! {
});
}
#[derive(Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct SetupResult {
pub tor_address: String,
pub lan_address: String,
pub root_ca: String,
}
impl TryFrom<&AccountInfo> for SetupResult {
type Error = Error;
fn try_from(value: &AccountInfo) -> Result<Self, Self::Error> {
Ok(Self {
tor_address: format!("https://{}", value.tor_key.public().get_onion_address()),
lan_address: value.hostname.lan_address(),
root_ca: String::from_utf8(value.root_ca_cert.to_pem()?)?,
})
}
}
pub struct SetupContextSeed {
pub config: ServerConfig,
pub os_partitions: OsPartitionInfo,
pub disable_encryption: bool,
pub progress: FullProgressTracker,
pub task: OnceCell<NonDetachingJoinHandle<()>>,
pub result: OnceCell<Result<(SetupResult, RpcContext), Error>>,
pub shutdown: Sender<()>,
pub datadir: PathBuf,
pub selected_v2_drive: RwLock<Option<PathBuf>>,
pub cached_product_key: RwLock<Option<Arc<String>>>,
pub setup_status: RwLock<Option<Result<SetupStatus, RpcError>>>,
pub setup_result: RwLock<Option<(Arc<String>, SetupResult)>>,
}
impl AsRef<Jwk> for SetupContextSeed {
fn as_ref(&self) -> &Jwk {
&*CURRENT_SECRET
}
pub rpc_continuations: RpcContinuations,
}
#[derive(Clone)]
@@ -69,12 +82,12 @@ impl SetupContext {
)
})?,
disable_encryption: config.disable_encryption.unwrap_or(false),
progress: FullProgressTracker::new(),
task: OnceCell::new(),
result: OnceCell::new(),
shutdown,
datadir,
selected_v2_drive: RwLock::new(None),
cached_product_key: RwLock::new(None),
setup_status: RwLock::new(None),
setup_result: RwLock::new(None),
rpc_continuations: RpcContinuations::new(),
})))
}
#[instrument(skip_all)]
@@ -97,6 +110,104 @@ impl SetupContext {
.with_kind(crate::ErrorKind::Database)?;
Ok(secret_store)
}
pub fn run_setup<F, Fut>(&self, f: F) -> Result<(), Error>
where
F: FnOnce() -> Fut + Send + 'static,
Fut: Future<Output = Result<(SetupResult, RpcContext), Error>> + Send,
{
let local_ctx = self.clone();
self.task
.set(
tokio::spawn(async move {
local_ctx
.result
.get_or_init(|| async {
match f().await {
Ok(res) => {
tracing::info!("Setup complete!");
Ok(res)
}
Err(e) => {
tracing::error!("Setup failed: {e}");
tracing::debug!("{e:?}");
Err(e)
}
}
})
.await;
local_ctx.progress.complete();
})
.into(),
)
.map_err(|_| {
if self.result.initialized() {
Error::new(eyre!("Setup already complete"), ErrorKind::InvalidRequest)
} else {
Error::new(
eyre!("Setup already in progress"),
ErrorKind::InvalidRequest,
)
}
})?;
Ok(())
}
pub async fn progress(&self) -> SetupProgress {
use axum::extract::ws;
let guid = Guid::new();
let progress_tracker = self.progress.clone();
let progress = progress_tracker.snapshot();
self.rpc_continuations
.add(
guid.clone(),
RpcContinuation::ws(
|mut ws| async move {
if let Err(e) = async {
let mut stream =
progress_tracker.stream(Some(Duration::from_millis(100)));
while let Some(progress) = stream.next().await {
ws.send(ws::Message::Text(
serde_json::to_string(&progress)
.with_kind(ErrorKind::Serialization)?,
))
.await
.with_kind(ErrorKind::Network)?;
if progress.overall.is_complete() {
break;
}
}
ws.normal_close("complete").await?;
Ok::<_, Error>(())
}
.await
{
tracing::error!("Error in setup progress websocket: {e}");
tracing::debug!("{e:?}");
}
},
Duration::from_secs(30),
),
)
.await;
SetupProgress { progress, guid }
}
}
impl AsRef<Jwk> for SetupContext {
fn as_ref(&self) -> &Jwk {
&*CURRENT_SECRET
}
}
impl AsRef<RpcContinuations> for SetupContext {
fn as_ref(&self) -> &RpcContinuations {
&self.rpc_continuations
}
}
impl Context for SetupContext {}

View File

@@ -7,6 +7,7 @@ use ts_rs::TS;
use crate::context::RpcContext;
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::Error;
#[derive(Deserialize, Serialize, Parser, TS)]
@@ -23,7 +24,7 @@ pub async fn start(ctx: RpcContext, ControlParams { id }: ControlParams) -> Resu
.await
.as_ref()
.or_not_found(lazy_format!("Manager for {id}"))?
.start()
.start(Guid::new())
.await?;
Ok(())
@@ -36,7 +37,7 @@ pub async fn stop(ctx: RpcContext, ControlParams { id }: ControlParams) -> Resul
.await
.as_ref()
.ok_or_else(|| Error::new(eyre!("Manager not found"), crate::ErrorKind::InvalidRequest))?
.stop()
.stop(Guid::new())
.await?;
Ok(())
@@ -48,7 +49,7 @@ pub async fn restart(ctx: RpcContext, ControlParams { id }: ControlParams) -> Re
.await
.as_ref()
.ok_or_else(|| Error::new(eyre!("Manager not found"), crate::ErrorKind::InvalidRequest))?
.restart()
.restart(Guid::new())
.await?;
Ok(())

View File

@@ -3,175 +3,40 @@ pub mod prelude;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use axum::extract::ws::{self, WebSocket};
use axum::extract::WebSocketUpgrade;
use axum::response::Response;
use axum::extract::ws;
use clap::Parser;
use futures::{FutureExt, StreamExt};
use http::header::COOKIE;
use http::HeaderMap;
use imbl_value::InternedString;
use itertools::Itertools;
use patch_db::json_ptr::{JsonPointer, ROOT};
use patch_db::{Dump, Revision};
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{from_fn_async, Context, HandlerArgs, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::oneshot;
use tracing::instrument;
use ts_rs::TS;
use crate::context::{CliContext, RpcContext};
use crate::middleware::auth::{HasValidSession, HashSessionToken};
use crate::prelude::*;
use crate::rpc_continuations::{Guid, RpcContinuation};
use crate::util::net::WebSocketExt;
use crate::util::serde::{apply_expr, HandlerExtSerde};
lazy_static::lazy_static! {
static ref PUBLIC: JsonPointer = "/public".parse().unwrap();
}
#[instrument(skip_all)]
async fn ws_handler(
ctx: RpcContext,
session: Option<(HasValidSession, HashSessionToken)>,
mut stream: WebSocket,
) -> Result<(), Error> {
let (dump, sub) = ctx.db.dump_and_sub(PUBLIC.clone()).await;
if let Some((session, token)) = session {
let kill = subscribe_to_session_kill(&ctx, token).await;
send_dump(session.clone(), &mut stream, dump).await?;
deal_with_messages(session, kill, sub, stream).await?;
} else {
stream
.send(ws::Message::Close(Some(ws::CloseFrame {
code: ws::close_code::ERROR,
reason: "UNAUTHORIZED".into(),
})))
.await
.with_kind(ErrorKind::Network)?;
drop(stream);
}
Ok(())
}
async fn subscribe_to_session_kill(
ctx: &RpcContext,
token: HashSessionToken,
) -> oneshot::Receiver<()> {
let (send, recv) = oneshot::channel();
let mut guard = ctx.open_authed_websockets.lock().await;
if !guard.contains_key(&token) {
guard.insert(token, vec![send]);
} else {
guard.get_mut(&token).unwrap().push(send);
}
recv
}
#[instrument(skip_all)]
async fn deal_with_messages(
_has_valid_authentication: HasValidSession,
mut kill: oneshot::Receiver<()>,
mut sub: patch_db::Subscriber,
mut stream: WebSocket,
) -> Result<(), Error> {
let mut timer = tokio::time::interval(tokio::time::Duration::from_secs(5));
loop {
futures::select! {
_ = (&mut kill).fuse() => {
tracing::info!("Closing WebSocket: Reason: Session Terminated");
stream
.send(ws::Message::Close(Some(ws::CloseFrame {
code: ws::close_code::ERROR,
reason: "UNAUTHORIZED".into(),
}))).await
.with_kind(ErrorKind::Network)?;
drop(stream);
return Ok(())
}
new_rev = sub.recv().fuse() => {
let rev = new_rev.expect("UNREACHABLE: patch-db is dropped");
stream
.send(ws::Message::Text(serde_json::to_string(&rev).with_kind(ErrorKind::Serialization)?))
.await
.with_kind(ErrorKind::Network)?;
}
message = stream.next().fuse() => {
let message = message.transpose().with_kind(ErrorKind::Network)?;
match message {
None => {
tracing::info!("Closing WebSocket: Stream Finished");
return Ok(())
}
_ => (),
}
}
// This is trying to give a health checks to the home to keep the ui alive.
_ = timer.tick().fuse() => {
stream
.send(ws::Message::Ping(vec![]))
.await
.with_kind(crate::ErrorKind::Network)?;
}
}
}
}
async fn send_dump(
_has_valid_authentication: HasValidSession,
stream: &mut WebSocket,
dump: Dump,
) -> Result<(), Error> {
stream
.send(ws::Message::Text(
serde_json::to_string(&dump).with_kind(ErrorKind::Serialization)?,
))
.await
.with_kind(ErrorKind::Network)?;
Ok(())
}
pub async fn subscribe(
ctx: RpcContext,
headers: HeaderMap,
ws: WebSocketUpgrade,
) -> Result<Response, Error> {
let session = match async {
let token = HashSessionToken::from_header(headers.get(COOKIE))?;
let session = HasValidSession::from_header(headers.get(COOKIE), &ctx).await?;
Ok::<_, Error>((session, token))
}
.await
{
Ok(a) => Some(a),
Err(e) => {
if e.kind != ErrorKind::Authorization {
tracing::error!("Error Authenticating Websocket: {}", e);
tracing::debug!("{:?}", e);
}
None
}
};
Ok(ws.on_upgrade(|ws| async move {
match ws_handler(ctx, session, ws).await {
Ok(()) => (),
Err(e) => {
tracing::error!("WebSocket Closed: {}", e);
tracing::debug!("{:?}", e);
}
}
}))
}
pub fn db<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
.subcommand("dump", from_fn_async(cli_dump).with_display_serializable())
.subcommand("dump", from_fn_async(dump).no_cli())
.subcommand(
"subscribe",
from_fn_async(subscribe)
.with_metadata("get_session", Value::Bool(true))
.no_cli(),
)
.subcommand("put", put::<C>())
.subcommand("apply", from_fn_async(cli_apply).no_display())
.subcommand("apply", from_fn_async(apply).no_cli())
@@ -215,7 +80,13 @@ async fn cli_dump(
context
.call_remote::<RpcContext>(
&method,
imbl_value::json!({ "includePrivate":include_private }),
imbl_value::json!({
"pointer": if include_private {
AsRef::<str>::as_ref(&ROOT)
} else {
AsRef::<str>::as_ref(&*PUBLIC)
}
}),
)
.await?,
)?
@@ -224,25 +95,76 @@ async fn cli_dump(
Ok(dump)
}
#[derive(Deserialize, Serialize, Parser, TS)]
#[derive(Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
#[command(rename_all = "kebab-case")]
pub struct DumpParams {
#[arg(long = "include-private", short = 'p')]
#[serde(default)]
#[ts(skip)]
include_private: bool,
#[ts(type = "string | null")]
pointer: Option<JsonPointer>,
}
pub async fn dump(
pub async fn dump(ctx: RpcContext, DumpParams { pointer }: DumpParams) -> Result<Dump, Error> {
Ok(ctx.db.dump(pointer.as_ref().unwrap_or(&*PUBLIC)).await)
}
#[derive(Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct SubscribeParams {
#[ts(type = "string | null")]
pointer: Option<JsonPointer>,
#[ts(skip)]
#[serde(rename = "__auth_session")]
session: InternedString,
}
#[derive(Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct SubscribeRes {
#[ts(type = "{ id: number; value: unknown }")]
pub dump: Dump,
pub guid: Guid,
}
pub async fn subscribe(
ctx: RpcContext,
DumpParams { include_private }: DumpParams,
) -> Result<Dump, Error> {
Ok(if include_private {
ctx.db.dump(&ROOT).await
} else {
ctx.db.dump(&PUBLIC).await
})
SubscribeParams { pointer, session }: SubscribeParams,
) -> Result<SubscribeRes, Error> {
let (dump, mut sub) = ctx
.db
.dump_and_sub(pointer.unwrap_or_else(|| PUBLIC.clone()))
.await;
let guid = Guid::new();
ctx.rpc_continuations
.add(
guid.clone(),
RpcContinuation::ws_authed(
&ctx,
session,
|mut ws| async move {
if let Err(e) = async {
while let Some(rev) = sub.recv().await {
ws.send(ws::Message::Text(
serde_json::to_string(&rev).with_kind(ErrorKind::Serialization)?,
))
.await
.with_kind(ErrorKind::Network)?;
}
ws.normal_close("complete").await?;
Ok::<_, Error>(())
}
.await
{
tracing::error!("Error in db websocket: {e}");
tracing::debug!("{e:?}");
}
},
Duration::from_secs(30),
),
)
.await;
Ok(SubscribeRes { dump, guid })
}
#[derive(Deserialize, Serialize, Parser)]

View File

@@ -10,8 +10,8 @@ use reqwest::Url;
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use crate::net::host::HostInfo;
use crate::net::service_interface::ServiceInterfaceWithHostInfo;
use crate::net::host::Hosts;
use crate::net::service_interface::ServiceInterface;
use crate::prelude::*;
use crate::progress::FullProgress;
use crate::s9pk::manifest::Manifest;
@@ -333,8 +333,8 @@ pub struct PackageDataEntry {
pub last_backup: Option<DateTime<Utc>>,
pub current_dependencies: CurrentDependencies,
pub actions: BTreeMap<ActionId, ActionMetadata>,
pub service_interfaces: BTreeMap<ServiceInterfaceId, ServiceInterfaceWithHostInfo>,
pub hosts: HostInfo,
pub service_interfaces: BTreeMap<ServiceInterfaceId, ServiceInterface>,
pub hosts: Hosts,
#[ts(type = "string[]")]
pub store_exposed_dependents: Vec<JsonPointer>,
}

View File

@@ -13,6 +13,7 @@ use crate::config::{Config, ConfigSpec, ConfigureContext};
use crate::context::RpcContext;
use crate::db::model::package::CurrentDependencies;
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::Error;
pub fn dependency<C: Context>() -> ParentHandler<C> {
@@ -86,7 +87,7 @@ pub async fn configure_impl(
ErrorKind::Unknown,
)
})?
.configure(configure_context)
.configure(Guid::new(), configure_context)
.await?;
Ok(())
}
@@ -103,14 +104,15 @@ pub async fn configure_logic(
ctx: RpcContext,
(dependent_id, dependency_id): (PackageId, PackageId),
) -> Result<ConfigDryRes, Error> {
let procedure_id = Guid::new();
let dependency_guard = ctx.services.get(&dependency_id).await;
let dependency = dependency_guard.as_ref().or_not_found(&dependency_id)?;
let dependent_guard = ctx.services.get(&dependent_id).await;
let dependent = dependent_guard.as_ref().or_not_found(&dependent_id)?;
let config_res = dependency.get_config().await?;
let config_res = dependency.get_config(procedure_id.clone()).await?;
let diff = Value::Object(
dependent
.dependency_config(dependency_id, config_res.config.clone())
.dependency_config(procedure_id, dependency_id, config_res.config.clone())
.await?
.unwrap_or_default(),
);
@@ -129,6 +131,7 @@ pub async fn compute_dependency_config_errs(
id: &PackageId,
current_dependencies: &mut CurrentDependencies,
) -> Result<(), Error> {
let procedure_id = Guid::new();
let service_guard = ctx.services.get(id).await;
let service = service_guard.as_ref().or_not_found(id)?;
for (dep_id, dep_info) in current_dependencies.0.iter_mut() {
@@ -137,10 +140,10 @@ pub async fn compute_dependency_config_errs(
continue;
};
let dep_config = dependency.get_config().await?.config;
let dep_config = dependency.get_config(procedure_id.clone()).await?.config;
dep_info.config_satisfied = service
.dependency_config(dep_id.clone(), dep_config)
.dependency_config(procedure_id.clone(), dep_id.clone(), dep_config)
.await?
.is_none();
}

View File

@@ -27,10 +27,6 @@ pub fn diagnostic<C: Context>() -> ParentHandler<C> {
"kernel-logs",
from_fn_async(crate::logs::cli_logs::<DiagnosticContext, Empty>).no_display(),
)
.subcommand(
"exit",
from_fn(exit).no_display().with_call_remote::<CliContext>(),
)
.subcommand(
"restart",
from_fn(restart)
@@ -51,20 +47,15 @@ pub fn error(ctx: DiagnosticContext) -> Result<Arc<RpcError>, Error> {
Ok(ctx.error.clone())
}
pub fn exit(ctx: DiagnosticContext) -> Result<(), Error> {
ctx.shutdown.send(None).expect("receiver dropped");
Ok(())
}
pub fn restart(ctx: DiagnosticContext) -> Result<(), Error> {
ctx.shutdown
.send(Some(Shutdown {
.send(Shutdown {
export_args: ctx
.disk_guid
.clone()
.map(|guid| (guid, ctx.datadir.clone())),
restart: true,
}))
})
.expect("receiver dropped");
Ok(())
}

View File

@@ -13,7 +13,7 @@ use crate::disk::mount::util::unmount;
use crate::util::Invoke;
use crate::{Error, ErrorKind, ResultExt};
pub const PASSWORD_PATH: &'static str = "/run/embassy/password";
pub const PASSWORD_PATH: &'static str = "/run/startos/password";
pub const DEFAULT_PASSWORD: &'static str = "password";
pub const MAIN_FS_SIZE: FsSize = FsSize::Gigabytes(8);

View File

@@ -1,8 +1,6 @@
use std::path::{Path, PathBuf};
use rpc_toolkit::{
from_fn_async, CallRemoteHandler, Context, Empty, HandlerExt, ParentHandler,
};
use rpc_toolkit::{from_fn_async, CallRemoteHandler, Context, Empty, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use crate::context::{CliContext, RpcContext};

View File

@@ -178,7 +178,6 @@ impl<G: GenericMountGuard> BackupMountGuard<G> {
Ok(())
}
}
#[async_trait::async_trait]
impl<G: GenericMountGuard> GenericMountGuard for BackupMountGuard<G> {
fn path(&self) -> &Path {
if let Some(guard) = &self.encrypted_guard {

View File

@@ -6,8 +6,8 @@ use digest::generic_array::GenericArray;
use digest::{Digest, OutputSizeUser};
use sha2::Sha256;
use crate::disk::mount::filesystem::{FileSystem, ReadOnly, ReadWrite};
use crate::disk::mount::guard::{GenericMountGuard, MountGuard, TmpMountGuard};
use crate::disk::mount::filesystem::{FileSystem, ReadWrite};
use crate::disk::mount::guard::{GenericMountGuard, MountGuard};
use crate::prelude::*;
use crate::util::io::TmpDir;
@@ -94,17 +94,13 @@ impl<
}
#[derive(Debug)]
pub struct OverlayGuard {
lower: Option<TmpMountGuard>,
pub struct OverlayGuard<G: GenericMountGuard> {
lower: Option<G>,
upper: Option<TmpDir>,
inner_guard: MountGuard,
}
impl OverlayGuard {
pub async fn mount(
base: &impl FileSystem,
mountpoint: impl AsRef<Path>,
) -> Result<Self, Error> {
let lower = TmpMountGuard::mount(base, ReadOnly).await?;
impl<G: GenericMountGuard> OverlayGuard<G> {
pub async fn mount(lower: G, mountpoint: impl AsRef<Path>) -> Result<Self, Error> {
let upper = TmpDir::new().await?;
let inner_guard = MountGuard::mount(
&OverlayFs::new(
@@ -140,16 +136,15 @@ impl OverlayGuard {
}
}
}
#[async_trait::async_trait]
impl GenericMountGuard for OverlayGuard {
impl<G: GenericMountGuard> GenericMountGuard for OverlayGuard<G> {
fn path(&self) -> &Path {
self.inner_guard.path()
}
async fn unmount(mut self) -> Result<(), Error> {
async fn unmount(self) -> Result<(), Error> {
self.unmount(false).await
}
}
impl Drop for OverlayGuard {
impl<G: GenericMountGuard> Drop for OverlayGuard<G> {
fn drop(&mut self) {
let lower = self.lower.take();
let upper = self.upper.take();

View File

@@ -2,6 +2,7 @@ use std::collections::BTreeMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Weak};
use futures::Future;
use lazy_static::lazy_static;
use models::ResultExt;
use tokio::sync::Mutex;
@@ -14,23 +15,20 @@ use crate::Error;
pub const TMP_MOUNTPOINT: &'static str = "/media/startos/tmp";
#[async_trait::async_trait]
pub trait GenericMountGuard: std::fmt::Debug + Send + Sync + 'static {
fn path(&self) -> &Path;
async fn unmount(mut self) -> Result<(), Error>;
fn unmount(self) -> impl Future<Output = Result<(), Error>> + Send;
}
#[async_trait::async_trait]
impl GenericMountGuard for Never {
fn path(&self) -> &Path {
match *self {}
}
async fn unmount(mut self) -> Result<(), Error> {
async fn unmount(self) -> Result<(), Error> {
match self {}
}
}
#[async_trait::async_trait]
impl<T> GenericMountGuard for Arc<T>
where
T: GenericMountGuard,
@@ -38,7 +36,7 @@ where
fn path(&self) -> &Path {
(&**self).path()
}
async fn unmount(mut self) -> Result<(), Error> {
async fn unmount(self) -> Result<(), Error> {
if let Ok(guard) = Arc::try_unwrap(self) {
guard.unmount().await?;
}
@@ -102,12 +100,11 @@ impl Drop for MountGuard {
}
}
}
#[async_trait::async_trait]
impl GenericMountGuard for MountGuard {
fn path(&self) -> &Path {
&self.mountpoint
}
async fn unmount(mut self) -> Result<(), Error> {
async fn unmount(self) -> Result<(), Error> {
MountGuard::unmount(self, false).await
}
}
@@ -165,12 +162,11 @@ impl TmpMountGuard {
std::mem::replace(self, unmounted)
}
}
#[async_trait::async_trait]
impl GenericMountGuard for TmpMountGuard {
fn path(&self) -> &Path {
self.guard.path()
}
async fn unmount(mut self) -> Result<(), Error> {
async fn unmount(self) -> Result<(), Error> {
self.guard.unmount().await
}
}
@@ -187,12 +183,11 @@ impl<G: GenericMountGuard> SubPath<G> {
Self { guard, path }
}
}
#[async_trait::async_trait]
impl<G: GenericMountGuard> GenericMountGuard for SubPath<G> {
fn path(&self) -> &Path {
self.path.as_path()
}
async fn unmount(mut self) -> Result<(), Error> {
async fn unmount(self) -> Result<(), Error> {
self.guard.unmount().await
}
}

View File

@@ -9,6 +9,7 @@ use tokio::process::Command;
use crate::disk::fsck::RequiresReboot;
use crate::prelude::*;
use crate::progress::PhaseProgressTrackerHandle;
use crate::util::Invoke;
use crate::PLATFORM;
@@ -49,12 +50,7 @@ pub fn display_firmware_update_result(result: RequiresReboot) {
}
}
/// We wanted to make sure during every init
/// that the firmware was the correct and updated for
/// systems like the Pure System that a new firmware
/// was released and the updates where pushed through the pure os.
// #[command(rename = "update-firmware", display(display_firmware_update_result))]
pub async fn update_firmware() -> Result<RequiresReboot, Error> {
pub async fn check_for_firmware_update() -> Result<Option<Firmware>, Error> {
let system_product_name = String::from_utf8(
Command::new("dmidecode")
.arg("-s")
@@ -74,22 +70,21 @@ pub async fn update_firmware() -> Result<RequiresReboot, Error> {
.trim()
.to_owned();
if system_product_name.is_empty() || bios_version.is_empty() {
return Ok(RequiresReboot(false));
return Ok(None);
}
let firmware_dir = Path::new("/usr/lib/startos/firmware");
for firmware in serde_json::from_str::<Vec<Firmware>>(
&tokio::fs::read_to_string("/usr/lib/startos/firmware.json").await?,
)
.with_kind(ErrorKind::Deserialization)?
{
let id = firmware.id;
let matches_product_name = firmware
.system_product_name
.map_or(true, |spn| spn == system_product_name);
.as_ref()
.map_or(true, |spn| spn == &system_product_name);
let matches_bios_version = firmware
.bios_version
.as_ref()
.map_or(Some(true), |bv| {
let mut semver_str = bios_version.as_str();
if let Some(prefix) = &bv.semver_prefix {
@@ -113,35 +108,45 @@ pub async fn update_firmware() -> Result<RequiresReboot, Error> {
})
.unwrap_or(false);
if firmware.platform.contains(&*PLATFORM) && matches_product_name && matches_bios_version {
let filename = format!("{id}.rom.gz");
let firmware_path = firmware_dir.join(&filename);
Command::new("sha256sum")
.arg("-c")
.input(Some(&mut std::io::Cursor::new(format!(
"{} {}",
firmware.shasum,
firmware_path.display()
))))
.invoke(ErrorKind::Filesystem)
.await?;
let mut rdr = if tokio::fs::metadata(&firmware_path).await.is_ok() {
GzipDecoder::new(BufReader::new(File::open(&firmware_path).await?))
} else {
return Err(Error::new(
eyre!("Firmware {id}.rom.gz not found in {firmware_dir:?}"),
ErrorKind::NotFound,
));
};
Command::new("flashrom")
.arg("-p")
.arg("internal")
.arg("-w-")
.input(Some(&mut rdr))
.invoke(ErrorKind::Firmware)
.await?;
return Ok(RequiresReboot(true));
return Ok(Some(firmware));
}
}
Ok(RequiresReboot(false))
Ok(None)
}
/// We wanted to make sure during every init
/// that the firmware was the correct and updated for
/// systems like the Pure System that a new firmware
/// was released and the updates where pushed through the pure os.
pub async fn update_firmware(firmware: Firmware) -> Result<(), Error> {
let id = &firmware.id;
let firmware_dir = Path::new("/usr/lib/startos/firmware");
let filename = format!("{id}.rom.gz");
let firmware_path = firmware_dir.join(&filename);
Command::new("sha256sum")
.arg("-c")
.input(Some(&mut std::io::Cursor::new(format!(
"{} {}",
firmware.shasum,
firmware_path.display()
))))
.invoke(ErrorKind::Filesystem)
.await?;
let mut rdr = if tokio::fs::metadata(&firmware_path).await.is_ok() {
GzipDecoder::new(BufReader::new(File::open(&firmware_path).await?))
} else {
return Err(Error::new(
eyre!("Firmware {id}.rom.gz not found in {firmware_dir:?}"),
ErrorKind::NotFound,
));
};
Command::new("flashrom")
.arg("-p")
.arg("internal")
.arg("-w-")
.input(Some(&mut rdr))
.invoke(ErrorKind::Firmware)
.await?;
Ok(())
}

View File

@@ -1,25 +1,40 @@
use std::fs::Permissions;
use std::io::Cursor;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::time::{Duration, SystemTime};
use axum::extract::ws::{self, CloseFrame};
use color_eyre::eyre::eyre;
use futures::{StreamExt, TryStreamExt};
use itertools::Itertools;
use models::ResultExt;
use rand::random;
use rpc_toolkit::{from_fn_async, Context, Empty, HandlerArgs, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use tracing::instrument;
use ts_rs::TS;
use crate::account::AccountInfo;
use crate::context::config::ServerConfig;
use crate::context::{CliContext, InitContext};
use crate::db::model::public::ServerStatus;
use crate::db::model::Database;
use crate::disk::mount::util::unmount;
use crate::middleware::auth::LOCAL_AUTH_COOKIE_PATH;
use crate::net::net_controller::PreInitNetController;
use crate::prelude::*;
use crate::progress::{
FullProgress, FullProgressTracker, PhaseProgressTrackerHandle, PhasedProgressBar,
};
use crate::rpc_continuations::{Guid, RpcContinuation};
use crate::ssh::SSH_AUTHORIZED_KEYS_FILE;
use crate::util::cpupower::{get_available_governors, get_preferred_governor, set_governor};
use crate::util::Invoke;
use crate::{Error, ARCH};
use crate::util::io::IOHook;
use crate::util::net::WebSocketExt;
use crate::util::{cpupower, Invoke};
use crate::Error;
pub const SYSTEM_REBUILD_PATH: &str = "/media/startos/config/system-rebuild";
pub const STANDBY_MODE_PATH: &str = "/media/startos/config/standby";
@@ -180,14 +195,114 @@ pub async fn init_postgres(datadir: impl AsRef<Path>) -> Result<(), Error> {
}
pub struct InitResult {
pub db: TypedPatchDb<Database>,
pub net_ctrl: PreInitNetController,
}
pub struct InitPhases {
preinit: Option<PhaseProgressTrackerHandle>,
local_auth: PhaseProgressTrackerHandle,
load_database: PhaseProgressTrackerHandle,
load_ssh_keys: PhaseProgressTrackerHandle,
start_net: PhaseProgressTrackerHandle,
mount_logs: PhaseProgressTrackerHandle,
load_ca_cert: PhaseProgressTrackerHandle,
load_wifi: PhaseProgressTrackerHandle,
init_tmp: PhaseProgressTrackerHandle,
set_governor: PhaseProgressTrackerHandle,
sync_clock: PhaseProgressTrackerHandle,
enable_zram: PhaseProgressTrackerHandle,
update_server_info: PhaseProgressTrackerHandle,
launch_service_network: PhaseProgressTrackerHandle,
run_migrations: PhaseProgressTrackerHandle,
validate_db: PhaseProgressTrackerHandle,
postinit: Option<PhaseProgressTrackerHandle>,
}
impl InitPhases {
pub fn new(handle: &FullProgressTracker) -> Self {
Self {
preinit: if Path::new("/media/startos/config/preinit.sh").exists() {
Some(handle.add_phase("Running preinit.sh".into(), Some(5)))
} else {
None
},
local_auth: handle.add_phase("Enabling local authentication".into(), Some(1)),
load_database: handle.add_phase("Loading database".into(), Some(5)),
load_ssh_keys: handle.add_phase("Loading SSH Keys".into(), Some(1)),
start_net: handle.add_phase("Starting network controller".into(), Some(1)),
mount_logs: handle.add_phase("Switching logs to write to data drive".into(), Some(1)),
load_ca_cert: handle.add_phase("Loading CA certificate".into(), Some(1)),
load_wifi: handle.add_phase("Loading WiFi configuration".into(), Some(1)),
init_tmp: handle.add_phase("Initializing temporary files".into(), Some(1)),
set_governor: handle.add_phase("Setting CPU performance profile".into(), Some(1)),
sync_clock: handle.add_phase("Synchronizing system clock".into(), Some(10)),
enable_zram: handle.add_phase("Enabling ZRAM".into(), Some(1)),
update_server_info: handle.add_phase("Updating server info".into(), Some(1)),
launch_service_network: handle.add_phase("Launching service intranet".into(), Some(10)),
run_migrations: handle.add_phase("Running migrations".into(), Some(10)),
validate_db: handle.add_phase("Validating database".into(), Some(1)),
postinit: if Path::new("/media/startos/config/postinit.sh").exists() {
Some(handle.add_phase("Running postinit.sh".into(), Some(5)))
} else {
None
},
}
}
}
pub async fn run_script<P: AsRef<Path>>(path: P, mut progress: PhaseProgressTrackerHandle) {
let script = path.as_ref();
progress.start();
if let Err(e) = async {
let script = tokio::fs::read_to_string(script).await?;
progress.set_total(script.as_bytes().iter().filter(|b| **b == b'\n').count() as u64);
let mut reader = IOHook::new(Cursor::new(script.as_bytes()));
reader.post_read(|buf| progress += buf.iter().filter(|b| **b == b'\n').count() as u64);
Command::new("/bin/bash")
.input(Some(&mut reader))
.invoke(ErrorKind::Unknown)
.await?;
Ok::<_, Error>(())
}
.await
{
tracing::error!("Error Running {}: {}", script.display(), e);
tracing::debug!("{:?}", e);
}
progress.complete();
}
#[instrument(skip_all)]
pub async fn init(cfg: &ServerConfig) -> Result<InitResult, Error> {
tokio::fs::create_dir_all("/run/embassy")
pub async fn init(
cfg: &ServerConfig,
InitPhases {
preinit,
mut local_auth,
mut load_database,
mut load_ssh_keys,
mut start_net,
mut mount_logs,
mut load_ca_cert,
mut load_wifi,
mut init_tmp,
mut set_governor,
mut sync_clock,
mut enable_zram,
mut update_server_info,
mut launch_service_network,
run_migrations,
mut validate_db,
postinit,
}: InitPhases,
) -> Result<InitResult, Error> {
if let Some(progress) = preinit {
run_script("/media/startos/config/preinit.sh", progress).await;
}
local_auth.start();
tokio::fs::create_dir_all("/run/startos")
.await
.with_ctx(|_| (crate::ErrorKind::Filesystem, "mkdir -p /run/embassy"))?;
.with_ctx(|_| (crate::ErrorKind::Filesystem, "mkdir -p /run/startos"))?;
if tokio::fs::metadata(LOCAL_AUTH_COOKIE_PATH).await.is_err() {
tokio::fs::write(
LOCAL_AUTH_COOKIE_PATH,
@@ -207,43 +322,41 @@ pub async fn init(cfg: &ServerConfig) -> Result<InitResult, Error> {
.invoke(crate::ErrorKind::Filesystem)
.await?;
}
local_auth.complete();
load_database.start();
let db = TypedPatchDb::<Database>::load_unchecked(cfg.db().await?);
let peek = db.peek().await;
load_database.complete();
tracing::info!("Opened PatchDB");
load_ssh_keys.start();
crate::ssh::sync_keys(
&peek.as_private().as_ssh_pubkeys().de()?,
SSH_AUTHORIZED_KEYS_FILE,
)
.await?;
load_ssh_keys.complete();
tracing::info!("Synced SSH Keys");
let account = AccountInfo::load(&peek)?;
let mut server_info = peek.as_public().as_server_info().de()?;
// write to ca cert store
tokio::fs::write(
"/usr/local/share/ca-certificates/startos-root-ca.crt",
account.root_ca_cert.to_pem()?,
start_net.start();
let net_ctrl = PreInitNetController::init(
db.clone(),
cfg.tor_control
.unwrap_or(SocketAddr::from(([127, 0, 0, 1], 9051))),
cfg.tor_socks.unwrap_or(SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::new(127, 0, 0, 1),
9050,
))),
&account.hostname,
account.tor_key,
)
.await?;
Command::new("update-ca-certificates")
.invoke(crate::ErrorKind::OpenSsl)
.await?;
crate::net::wifi::synchronize_wpa_supplicant_conf(
&cfg.datadir().join("main"),
&mut server_info.wifi,
)
.await?;
tracing::info!("Synchronized WiFi");
let should_rebuild = tokio::fs::metadata(SYSTEM_REBUILD_PATH).await.is_ok()
|| &*server_info.version < &emver::Version::new(0, 3, 2, 0)
|| (*ARCH == "x86_64" && &*server_info.version < &emver::Version::new(0, 3, 4, 0));
start_net.complete();
mount_logs.start();
let log_dir = cfg.datadir().join("main/logs");
if tokio::fs::metadata(&log_dir).await.is_err() {
tokio::fs::create_dir_all(&log_dir).await?;
@@ -272,10 +385,35 @@ pub async fn init(cfg: &ServerConfig) -> Result<InitResult, Error> {
.arg("systemd-journald")
.invoke(crate::ErrorKind::Journald)
.await?;
mount_logs.complete();
tracing::info!("Mounted Logs");
let mut server_info = peek.as_public().as_server_info().de()?;
load_ca_cert.start();
// write to ca cert store
tokio::fs::write(
"/usr/local/share/ca-certificates/startos-root-ca.crt",
account.root_ca_cert.to_pem()?,
)
.await?;
Command::new("update-ca-certificates")
.invoke(crate::ErrorKind::OpenSsl)
.await?;
load_ca_cert.complete();
load_wifi.start();
crate::net::wifi::synchronize_wpa_supplicant_conf(
&cfg.datadir().join("main"),
&mut server_info.wifi,
)
.await?;
load_wifi.complete();
tracing::info!("Synchronized WiFi");
init_tmp.start();
let tmp_dir = cfg.datadir().join("package-data/tmp");
if should_rebuild && tokio::fs::metadata(&tmp_dir).await.is_ok() {
if tokio::fs::metadata(&tmp_dir).await.is_ok() {
tokio::fs::remove_dir_all(&tmp_dir).await?;
}
if tokio::fs::metadata(&tmp_dir).await.is_err() {
@@ -286,23 +424,30 @@ pub async fn init(cfg: &ServerConfig) -> Result<InitResult, Error> {
tokio::fs::remove_dir_all(&tmp_var).await?;
}
crate::disk::mount::util::bind(&tmp_var, "/var/tmp", false).await?;
init_tmp.complete();
set_governor.start();
let governor = if let Some(governor) = &server_info.governor {
if get_available_governors().await?.contains(governor) {
if cpupower::get_available_governors()
.await?
.contains(governor)
{
Some(governor)
} else {
tracing::warn!("CPU Governor \"{governor}\" Not Available");
None
}
} else {
get_preferred_governor().await?
cpupower::get_preferred_governor().await?
};
if let Some(governor) = governor {
tracing::info!("Setting CPU Governor to \"{governor}\"");
set_governor(governor).await?;
cpupower::set_governor(governor).await?;
tracing::info!("Set CPU Governor");
}
set_governor.complete();
sync_clock.start();
server_info.ntp_synced = false;
let mut not_made_progress = 0u32;
for _ in 0..1800 {
@@ -329,10 +474,15 @@ pub async fn init(cfg: &ServerConfig) -> Result<InitResult, Error> {
} else {
tracing::info!("Syncronized system clock");
}
sync_clock.complete();
enable_zram.start();
if server_info.zram {
crate::system::enable_zram().await?
}
enable_zram.complete();
update_server_info.start();
server_info.ip_info = crate::net::dhcp::init_ips().await?;
server_info.status_info = ServerStatus {
updated: false,
@@ -341,36 +491,129 @@ pub async fn init(cfg: &ServerConfig) -> Result<InitResult, Error> {
shutting_down: false,
restarting: false,
};
db.mutate(|v| {
v.as_public_mut().as_server_info_mut().ser(&server_info)?;
Ok(())
})
.await?;
update_server_info.complete();
launch_service_network.start();
Command::new("systemctl")
.arg("start")
.arg("lxc-net.service")
.invoke(ErrorKind::Lxc)
.await?;
launch_service_network.complete();
crate::version::init(&db).await?;
crate::version::init(&db, run_migrations).await?;
validate_db.start();
db.mutate(|d| {
let model = d.de()?;
d.ser(&model)
})
.await?;
validate_db.complete();
if should_rebuild {
match tokio::fs::remove_file(SYSTEM_REBUILD_PATH).await {
Ok(()) => Ok(()),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
Err(e) => Err(e),
}?;
if let Some(progress) = postinit {
run_script("/media/startos/config/postinit.sh", progress).await;
}
tracing::info!("System initialized.");
Ok(InitResult { db })
Ok(InitResult { net_ctrl })
}
pub fn init_api<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
.subcommand("logs", crate::system::logs::<InitContext>())
.subcommand(
"logs",
from_fn_async(crate::logs::cli_logs::<InitContext, Empty>).no_display(),
)
.subcommand("kernel-logs", crate::system::kernel_logs::<InitContext>())
.subcommand(
"kernel-logs",
from_fn_async(crate::logs::cli_logs::<InitContext, Empty>).no_display(),
)
.subcommand("subscribe", from_fn_async(init_progress).no_cli())
.subcommand("subscribe", from_fn_async(cli_init_progress).no_display())
}
/// Response payload for the init "subscribe" RPC: an immediate snapshot of the
/// current initialization progress plus the GUID of a registered websocket
/// continuation that streams subsequent progress updates.
#[derive(Debug, Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct InitProgressRes {
    /// Progress snapshot taken at the time the request was handled.
    pub progress: FullProgress,
    /// GUID identifying the websocket continuation for live progress updates.
    pub guid: Guid,
}
/// RPC handler: snapshot the current init progress and register a websocket
/// continuation that streams subsequent progress to the client.
///
/// Returns the snapshot plus the continuation GUID; the caller connects to the
/// websocket identified by that GUID to follow progress until
/// `progress.overall.is_complete()`, at which point the socket is closed
/// normally with reason "complete".
pub async fn init_progress(ctx: InitContext) -> Result<InitProgressRes, Error> {
    let progress_tracker = ctx.progress.clone();
    // Snapshot now so the caller has a consistent starting point even if it
    // never connects to the websocket.
    let progress = progress_tracker.snapshot();
    let guid = Guid::new();
    ctx.rpc_continuations
        .add(
            guid.clone(),
            RpcContinuation::ws(
                |mut ws| async move {
                    if let Err(e) = async {
                        // Stream tracker updates (100ms interval argument) and
                        // forward each one to the client as a JSON text frame.
                        let mut stream = progress_tracker.stream(Some(Duration::from_millis(100)));
                        while let Some(progress) = stream.next().await {
                            ws.send(ws::Message::Text(
                                serde_json::to_string(&progress)
                                    .with_kind(ErrorKind::Serialization)?,
                            ))
                            .await
                            .with_kind(ErrorKind::Network)?;
                            // Stop streaming once initialization has finished.
                            if progress.overall.is_complete() {
                                break;
                            }
                        }
                        ws.normal_close("complete").await?;
                        Ok::<_, Error>(())
                    }
                    .await
                    {
                        // Failures only affect this subscriber: log and drop.
                        tracing::error!("error in init progress websocket: {e}");
                        tracing::debug!("{e:?}");
                    }
                },
                // 30s continuation duration — presumably the window in which
                // the client must attach (matches other ws continuations in
                // this codebase); TODO confirm against RpcContinuation::ws.
                Duration::from_secs(30),
            ),
        )
        .await;
    Ok(InitProgressRes { progress, guid })
}
/// CLI-side companion to the init progress subscription: calls the remote
/// handler, attaches to the returned websocket continuation, and renders each
/// streamed progress update on a local phased progress bar until the server
/// closes the socket.
pub async fn cli_init_progress(
    HandlerArgs {
        context: ctx,
        parent_method,
        method,
        raw_params,
        ..
    }: HandlerArgs<CliContext>,
) -> Result<(), Error> {
    // Reconstruct the dotted RPC method path from the handler hierarchy.
    let rpc_method = parent_method
        .into_iter()
        .chain(method.into_iter())
        .join(".");
    let raw = ctx
        .call_remote::<InitContext>(&rpc_method, raw_params)
        .await?;
    let subscription: InitProgressRes = from_value(raw)?;
    // Attach to the continuation and drain it, updating the bar per frame.
    let mut socket = ctx.ws_continuation(subscription.guid).await?;
    let mut progress_bar = PhasedProgressBar::new("Initializing...");
    while let Some(frame) = socket.try_next().await.with_kind(ErrorKind::Network)? {
        if let tokio_tungstenite::tungstenite::Message::Text(body) = frame {
            let parsed = serde_json::from_str(&body).with_kind(ErrorKind::Deserialization)?;
            progress_bar.update(&parsed);
        }
    }
    Ok(())
}

View File

@@ -6,7 +6,8 @@ use clap::builder::ValueParserFactory;
use clap::{value_parser, CommandFactory, FromArgMatches, Parser};
use color_eyre::eyre::eyre;
use emver::VersionRange;
use futures::{FutureExt, StreamExt};
use futures::StreamExt;
use imbl_value::InternedString;
use itertools::Itertools;
use patch_db::json_ptr::JsonPointer;
use reqwest::header::{HeaderMap, CONTENT_LENGTH};
@@ -29,6 +30,7 @@ use crate::s9pk::merkle_archive::source::http::HttpSource;
use crate::s9pk::S9pk;
use crate::upload::upload;
use crate::util::clap::FromStrParser;
use crate::util::net::WebSocketExt;
use crate::util::Never;
pub const PKG_ARCHIVE_DIR: &str = "package-data/archive";
@@ -152,7 +154,6 @@ pub async fn install(
.await?,
),
None, // TODO
true,
)
.await?;
@@ -171,7 +172,15 @@ pub async fn install(
Ok(())
}
#[derive(Deserialize, Serialize)]
#[derive(Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct SideloadParams {
#[ts(skip)]
#[serde(rename = "__auth_session")]
session: InternedString,
}
#[derive(Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct SideloadResponse {
pub upload: Guid,
@@ -179,8 +188,11 @@ pub struct SideloadResponse {
}
#[instrument(skip_all)]
pub async fn sideload(ctx: RpcContext) -> Result<SideloadResponse, Error> {
let (upload, file) = upload(&ctx).await?;
pub async fn sideload(
ctx: RpcContext,
SideloadParams { session }: SideloadParams,
) -> Result<SideloadResponse, Error> {
let (upload, file) = upload(&ctx, session.clone()).await?;
let (id_send, id_recv) = oneshot::channel();
let (err_send, err_recv) = oneshot::channel();
let progress = Guid::new();
@@ -194,17 +206,27 @@ pub async fn sideload(ctx: RpcContext) -> Result<SideloadResponse, Error> {
.await;
ctx.rpc_continuations.add(
progress.clone(),
RpcContinuation::ws(
Box::new(|mut ws| {
RpcContinuation::ws_authed(&ctx, session,
|mut ws| {
use axum::extract::ws::Message;
async move {
if let Err(e) = async {
let id = id_recv.await.map_err(|_| {
let id = match id_recv.await.map_err(|_| {
Error::new(
eyre!("Could not get id to watch progress"),
ErrorKind::Cancelled,
)
})?;
}).and_then(|a|a) {
Ok(a) => a,
Err(e) =>{ ws.send(Message::Text(
serde_json::to_string(&Err::<(), _>(RpcError::from(e.clone_output())))
.with_kind(ErrorKind::Serialization)?,
))
.await
.with_kind(ErrorKind::Network)?;
return Err(e);
}
};
tokio::select! {
res = async {
while let Some(_) = sub.recv().await {
@@ -242,7 +264,7 @@ pub async fn sideload(ctx: RpcContext) -> Result<SideloadResponse, Error> {
}
}
ws.close().await.with_kind(ErrorKind::Network)?;
ws.normal_close("complete").await?;
Ok::<_, Error>(())
}
@@ -252,26 +274,32 @@ pub async fn sideload(ctx: RpcContext) -> Result<SideloadResponse, Error> {
tracing::debug!("{e:?}");
}
}
.boxed()
}),
},
Duration::from_secs(600),
),
)
.await;
tokio::spawn(async move {
if let Err(e) = async {
let s9pk = S9pk::deserialize(
match S9pk::deserialize(
&file, None, // TODO
true,
)
.await?;
let _ = id_send.send(s9pk.as_manifest().id.clone());
ctx.services
.install(ctx.clone(), s9pk, None::<Never>)
.await?
.await?
.await?;
file.delete().await
.await
{
Ok(s9pk) => {
let _ = id_send.send(Ok(s9pk.as_manifest().id.clone()));
ctx.services
.install(ctx.clone(), s9pk, None::<Never>)
.await?
.await?
.await?;
file.delete().await
}
Err(e) => {
let _ = id_send.send(Err(e.clone_output()));
return Err(e);
}
}
}
.await
{

View File

@@ -1,15 +1,8 @@
pub const DEFAULT_MARKETPLACE: &str = "https://registry.start9.com";
// pub const COMMUNITY_MARKETPLACE: &str = "https://community-registry.start9.com";
pub const CAP_1_KiB: usize = 1024;
pub const CAP_1_MiB: usize = CAP_1_KiB * CAP_1_KiB;
pub const CAP_10_MiB: usize = 10 * CAP_1_MiB;
pub const HOST_IP: [u8; 4] = [172, 18, 0, 1];
pub const TARGET: &str = current_platform::CURRENT_PLATFORM;
pub use std::env::consts::ARCH;
lazy_static::lazy_static! {
pub static ref ARCH: &'static str = {
let (arch, _) = TARGET.split_once("-").unwrap();
arch
};
pub static ref PLATFORM: String = {
if let Ok(platform) = std::fs::read_to_string("/usr/lib/startos/PLATFORM.txt") {
platform
@@ -22,6 +15,15 @@ lazy_static::lazy_static! {
};
}
/// Byte-size capacity constants. The mixed-case `KiB`/`MiB` names are
/// intentional, hence the lint allowance inside the module.
mod cap {
    #![allow(non_upper_case_globals)]
    /// One kibibyte (1024 bytes).
    pub const CAP_1_KiB: usize = 1 << 10;
    /// One mebibyte (1024 * 1024 bytes).
    pub const CAP_1_MiB: usize = 1 << 20;
    /// Ten mebibytes.
    pub const CAP_10_MiB: usize = 10 << 20;
}
pub use cap::*;
pub mod account;
pub mod action;
pub mod auth;
@@ -79,13 +81,17 @@ use rpc_toolkit::{
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use crate::context::{CliContext, DiagnosticContext, InstallContext, RpcContext, SetupContext};
use crate::context::{
CliContext, DiagnosticContext, InitContext, InstallContext, RpcContext, SetupContext,
};
use crate::disk::fsck::RequiresReboot;
use crate::registry::context::{RegistryContext, RegistryUrlParams};
use crate::util::serde::HandlerExtSerde;
#[derive(Deserialize, Serialize, Parser, TS)]
#[serde(rename_all = "camelCase")]
#[command(rename_all = "kebab-case")]
#[ts(export)]
pub struct EchoParams {
message: String,
}
@@ -94,6 +100,20 @@ pub fn echo<C: Context>(_: C, EchoParams { message }: EchoParams) -> Result<Stri
Ok(message)
}
/// Which of the server's mutually exclusive RPC APIs is currently serving:
/// the diagnostic API reports `Error`, the init API reports `Initializing`,
/// and the main API reports `Running`.
#[derive(Debug, Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub enum ApiState {
    Error,
    Initializing,
    Running,
}
impl std::fmt::Display for ApiState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Display mirrors the Debug representation (the bare variant name).
        write!(f, "{:?}", self)
    }
}
pub fn main_api<C: Context>() -> ParentHandler<C> {
let api = ParentHandler::new()
.subcommand::<C, _>("git-info", from_fn(version::git_info))
@@ -103,6 +123,12 @@ pub fn main_api<C: Context>() -> ParentHandler<C> {
.with_metadata("authenticated", Value::Bool(false))
.with_call_remote::<CliContext>(),
)
.subcommand(
"state",
from_fn(|_: RpcContext| Ok::<_, Error>(ApiState::Running))
.with_metadata("authenticated", Value::Bool(false))
.with_call_remote::<CliContext>(),
)
.subcommand("server", server::<C>())
.subcommand("package", package::<C>())
.subcommand("net", net::net::<C>())
@@ -185,11 +211,18 @@ pub fn server<C: Context>() -> ParentHandler<C> {
)
.subcommand(
"update-firmware",
from_fn_async(|_: RpcContext| firmware::update_firmware())
.with_custom_display_fn(|_handle, result| {
Ok(firmware::display_firmware_update_result(result))
})
.with_call_remote::<CliContext>(),
from_fn_async(|_: RpcContext| async {
if let Some(firmware) = firmware::check_for_firmware_update().await? {
firmware::update_firmware(firmware).await?;
Ok::<_, Error>(RequiresReboot(true))
} else {
Ok(RequiresReboot(false))
}
})
.with_custom_display_fn(|_handle, result| {
Ok(firmware::display_firmware_update_result(result))
})
.with_call_remote::<CliContext>(),
)
}
@@ -210,7 +243,12 @@ pub fn package<C: Context>() -> ParentHandler<C> {
.with_metadata("sync_db", Value::Bool(true))
.no_cli(),
)
.subcommand("sideload", from_fn_async(install::sideload).no_cli())
.subcommand(
"sideload",
from_fn_async(install::sideload)
.with_metadata("get_session", Value::Bool(true))
.no_cli(),
)
.subcommand("install", from_fn_async(install::cli_install).no_display())
.subcommand(
"uninstall",
@@ -279,9 +317,34 @@ pub fn diagnostic_api() -> ParentHandler<DiagnosticContext> {
"echo",
from_fn(echo::<DiagnosticContext>).with_call_remote::<CliContext>(),
)
.subcommand(
"state",
from_fn(|_: DiagnosticContext| Ok::<_, Error>(ApiState::Error))
.with_metadata("authenticated", Value::Bool(false))
.with_call_remote::<CliContext>(),
)
.subcommand("diagnostic", diagnostic::diagnostic::<DiagnosticContext>())
}
/// Builds the RPC handler tree served while the daemon is initializing.
/// Only exposes unauthenticated introspection (`git-info`, `state`), `echo`,
/// and the `init` subcommands.
pub fn init_api() -> ParentHandler<InitContext> {
    ParentHandler::new()
        .subcommand::<InitContext, _>(
            "git-info",
            from_fn(version::git_info).with_metadata("authenticated", Value::Bool(false)),
        )
        .subcommand(
            "echo",
            from_fn(echo::<InitContext>).with_call_remote::<CliContext>(),
        )
        .subcommand(
            "state",
            // Reports `Initializing` so clients can tell this API apart from
            // the main (`Running`) and diagnostic (`Error`) APIs, whose
            // `state` handlers return their own variants.
            from_fn(|_: InitContext| Ok::<_, Error>(ApiState::Initializing))
                .with_metadata("authenticated", Value::Bool(false))
                .with_call_remote::<CliContext>(),
        )
        .subcommand("init", init::init_api::<InitContext>())
}
pub fn setup_api() -> ParentHandler<SetupContext> {
ParentHandler::new()
.subcommand::<SetupContext, _>(

View File

@@ -5,7 +5,7 @@ use std::sync::{Arc, Weak};
use std::time::Duration;
use clap::builder::ValueParserFactory;
use futures::{AsyncWriteExt, FutureExt, StreamExt};
use futures::{AsyncWriteExt, StreamExt};
use imbl_value::{InOMap, InternedString};
use models::InvalidId;
use rpc_toolkit::yajrc::RpcError;
@@ -24,7 +24,7 @@ use crate::disk::mount::filesystem::bind::Bind;
use crate::disk::mount::filesystem::block_dev::BlockDev;
use crate::disk::mount::filesystem::idmapped::IdMapped;
use crate::disk::mount::filesystem::overlayfs::OverlayGuard;
use crate::disk::mount::filesystem::{MountType, ReadWrite};
use crate::disk::mount::filesystem::{MountType, ReadOnly, ReadWrite};
use crate::disk::mount::guard::{GenericMountGuard, MountGuard, TmpMountGuard};
use crate::disk::mount::util::unmount;
use crate::prelude::*;
@@ -151,7 +151,7 @@ impl LxcManager {
pub struct LxcContainer {
manager: Weak<LxcManager>,
rootfs: OverlayGuard,
rootfs: OverlayGuard<TmpMountGuard>,
pub guid: Arc<ContainerId>,
rpc_bind: TmpMountGuard,
log_mount: Option<MountGuard>,
@@ -182,12 +182,16 @@ impl LxcContainer {
.invoke(ErrorKind::Filesystem)
.await?;
let rootfs = OverlayGuard::mount(
&IdMapped::new(
BlockDev::new("/usr/lib/startos/container-runtime/rootfs.squashfs"),
0,
100000,
65536,
),
TmpMountGuard::mount(
&IdMapped::new(
BlockDev::new("/usr/lib/startos/container-runtime/rootfs.squashfs"),
0,
100000,
65536,
),
ReadOnly,
)
.await?,
&rootfs_dir,
)
.await?;
@@ -376,49 +380,46 @@ pub async fn connect(ctx: &RpcContext, container: &LxcContainer) -> Result<Guid,
.add(
guid.clone(),
RpcContinuation::ws(
Box::new(|mut ws| {
async move {
if let Err(e) = async {
loop {
match ws.next().await {
None => break,
Some(Ok(Message::Text(txt))) => {
let mut id = None;
let result = async {
let req: RpcRequest = serde_json::from_str(&txt)
.map_err(|e| RpcError {
data: Some(serde_json::Value::String(
e.to_string(),
)),
..rpc_toolkit::yajrc::PARSE_ERROR
})?;
id = req.id;
rpc.request(req.method, req.params).await
}
.await;
ws.send(Message::Text(
serde_json::to_string(&RpcResponse { id, result })
.with_kind(ErrorKind::Serialization)?,
))
.await
.with_kind(ErrorKind::Network)?;
}
Some(Ok(_)) => (),
Some(Err(e)) => {
return Err(Error::new(e, ErrorKind::Network));
|mut ws| async move {
if let Err(e) = async {
loop {
match ws.next().await {
None => break,
Some(Ok(Message::Text(txt))) => {
let mut id = None;
let result = async {
let req: RpcRequest =
serde_json::from_str(&txt).map_err(|e| RpcError {
data: Some(serde_json::Value::String(
e.to_string(),
)),
..rpc_toolkit::yajrc::PARSE_ERROR
})?;
id = req.id;
rpc.request(req.method, req.params).await
}
.await;
ws.send(Message::Text(
serde_json::to_string(&RpcResponse { id, result })
.with_kind(ErrorKind::Serialization)?,
))
.await
.with_kind(ErrorKind::Network)?;
}
Some(Ok(_)) => (),
Some(Err(e)) => {
return Err(Error::new(e, ErrorKind::Network));
}
}
Ok::<_, Error>(())
}
.await
{
tracing::error!("{e}");
tracing::debug!("{e:?}");
}
Ok::<_, Error>(())
}
.boxed()
}),
.await
{
tracing::error!("{e}");
tracing::debug!("{e:?}");
}
},
Duration::from_secs(30),
),
)

View File

@@ -23,7 +23,7 @@ use tokio::sync::Mutex;
use crate::context::RpcContext;
use crate::prelude::*;
pub const LOCAL_AUTH_COOKIE_PATH: &str = "/run/embassy/rpc.authcookie";
pub const LOCAL_AUTH_COOKIE_PATH: &str = "/run/startos/rpc.authcookie";
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
@@ -48,19 +48,9 @@ impl HasLoggedOutSessions {
.into_iter()
.map(|s| s.as_logout_session_id())
.collect();
ctx.open_authed_websockets
.lock()
.await
.retain(|session, sockets| {
if to_log_out.contains(session.hashed()) {
for socket in std::mem::take(sockets) {
let _ = socket.send(());
}
false
} else {
true
}
});
for sid in &to_log_out {
ctx.open_authed_continuations.kill(sid)
}
ctx.db
.mutate(|db| {
let sessions = db.as_private_mut().as_sessions_mut();

View File

@@ -1,42 +0,0 @@
use rpc_toolkit::yajrc::RpcMethod;
use rpc_toolkit::{Empty, Middleware, RpcRequest, RpcResponse};
use crate::context::DiagnosticContext;
use crate::prelude::*;
/// RPC middleware for the diagnostic API: remembers the requested method name
/// so that "method not found" responses can be rewritten into an explanatory
/// diagnostic-mode error (see `process_rpc_response`).
#[derive(Clone)]
pub struct DiagnosticMode {
    // Method name captured from the most recent request; used when building
    // the replacement error message.
    method: Option<String>,
}
impl DiagnosticMode {
    pub fn new() -> Self {
        Self { method: None }
    }
}
impl Middleware<DiagnosticContext> for DiagnosticMode {
    type Metadata = Empty;
    /// Records the requested method name so the response hook can mention it
    /// in its error message.
    async fn process_rpc_request(
        &mut self,
        _: &DiagnosticContext,
        _: Self::Metadata,
        request: &mut RpcRequest,
    ) -> Result<(), RpcResponse> {
        self.method = Some(request.method.as_str().to_owned());
        Ok(())
    }
    /// Rewrites JSON-RPC "method not found" errors into a DiagnosticMode
    /// error explaining that the method is unavailable on the Diagnostic API.
    async fn process_rpc_response(&mut self, _: &DiagnosticContext, response: &mut RpcResponse) {
        if let Err(e) = &mut response.result {
            // -32601 is the JSON-RPC 2.0 code for "method not found".
            if e.code == -32601 {
                *e = Error::new(
                    eyre!(
                        "{} is not available on the Diagnostic API",
                        self.method.as_ref().map(|s| s.as_str()).unwrap_or_default()
                    ),
                    crate::ErrorKind::DiagnosticMode,
                )
                .into();
            }
        }
    }
}

View File

@@ -1,4 +1,3 @@
pub mod auth;
pub mod cors;
pub mod db;
pub mod diagnostic;

View File

@@ -3,7 +3,7 @@ use std::net::IpAddr;
use clap::Parser;
use futures::TryStreamExt;
use rpc_toolkit::{from_fn_async, Context, HandlerExt, ParentHandler};
use rpc_toolkit::{from_fn_async, Context, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use ts_rs::TS;

View File

@@ -1,4 +1,3 @@
use imbl_value::InternedString;
use serde::{Deserialize, Serialize};
use ts_rs::TS;
@@ -11,17 +10,31 @@ use crate::prelude::*;
#[ts(export)]
pub struct BindInfo {
pub options: BindOptions,
pub assigned_lan_port: Option<u16>,
pub lan: LanInfo,
}
/// LAN port assignments for a single binding: the plain listener port and the
/// SSL listener port, either of which may be unallocated.
#[derive(Clone, Copy, Debug, Deserialize, Serialize, TS, PartialEq, Eq, PartialOrd, Ord)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct LanInfo {
    /// Port allocated for the non-SSL listener, if any.
    pub assigned_port: Option<u16>,
    /// Port allocated for the SSL-terminated listener, if any.
    pub assigned_ssl_port: Option<u16>,
}
impl BindInfo {
pub fn new(available_ports: &mut AvailablePorts, options: BindOptions) -> Result<Self, Error> {
let mut assigned_lan_port = None;
if options.add_ssl.is_some() || options.secure.is_some() {
assigned_lan_port = Some(available_ports.alloc()?);
let mut assigned_port = None;
let mut assigned_ssl_port = None;
if options.secure.is_some() {
assigned_port = Some(available_ports.alloc()?);
}
if options.add_ssl.is_some() {
assigned_ssl_port = Some(available_ports.alloc()?);
}
Ok(Self {
options,
assigned_lan_port,
lan: LanInfo {
assigned_port,
assigned_ssl_port,
},
})
}
pub fn update(
@@ -29,29 +42,38 @@ impl BindInfo {
available_ports: &mut AvailablePorts,
options: BindOptions,
) -> Result<Self, Error> {
let Self {
mut assigned_lan_port,
..
} = self;
if options.add_ssl.is_some() || options.secure.is_some() {
assigned_lan_port = if let Some(port) = assigned_lan_port.take() {
let Self { mut lan, .. } = self;
if options
.secure
.map_or(false, |s| !(s.ssl && options.add_ssl.is_some()))
// doesn't make sense to have 2 listening ports, both with ssl
{
lan.assigned_port = if let Some(port) = lan.assigned_port.take() {
Some(port)
} else {
Some(available_ports.alloc()?)
};
} else {
if let Some(port) = assigned_lan_port.take() {
if let Some(port) = lan.assigned_port.take() {
available_ports.free([port]);
}
}
Ok(Self {
options,
assigned_lan_port,
})
if options.add_ssl.is_some() {
lan.assigned_ssl_port = if let Some(port) = lan.assigned_ssl_port.take() {
Some(port)
} else {
Some(available_ports.alloc()?)
};
} else {
if let Some(port) = lan.assigned_ssl_port.take() {
available_ports.free([port]);
}
}
Ok(Self { options, lan })
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, TS)]
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, TS)]
#[ts(export)]
#[serde(rename_all = "camelCase")]
pub struct Security {
@@ -62,8 +84,6 @@ pub struct Security {
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct BindOptions {
#[ts(type = "string | null")]
pub scheme: Option<InternedString>,
pub preferred_external_port: u16,
pub add_ssl: Option<AddSslOptions>,
pub secure: Option<Security>,
@@ -73,11 +93,8 @@ pub struct BindOptions {
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct AddSslOptions {
#[ts(type = "string | null")]
pub scheme: Option<InternedString>,
pub preferred_external_port: u16,
// #[serde(default)]
// pub add_x_forwarded_headers: bool, // TODO
#[serde(default)]
pub alpn: AlpnInfo,
pub alpn: Option<AlpnInfo>,
}

View File

@@ -3,13 +3,13 @@ use std::collections::{BTreeMap, BTreeSet};
use imbl_value::InternedString;
use models::{HostId, PackageId};
use serde::{Deserialize, Serialize};
use torut::onion::{OnionAddressV3, TorSecretKeyV3};
use ts_rs::TS;
use crate::db::model::DatabaseModel;
use crate::net::forward::AvailablePorts;
use crate::net::host::address::HostAddress;
use crate::net::host::binding::{BindInfo, BindOptions};
use crate::net::service_interface::HostnameInfo;
use crate::prelude::*;
pub mod address;
@@ -23,7 +23,8 @@ pub struct Host {
pub kind: HostKind,
pub bindings: BTreeMap<u16, BindInfo>,
pub addresses: BTreeSet<HostAddress>,
pub primary: Option<HostAddress>,
/// COMPUTED: NetService::update
pub hostname_info: BTreeMap<u16, Vec<HostnameInfo>>, // internal port -> Hostnames
}
impl AsRef<Host> for Host {
fn as_ref(&self) -> &Host {
@@ -36,7 +37,7 @@ impl Host {
kind,
bindings: BTreeMap::new(),
addresses: BTreeSet::new(),
primary: None,
hostname_info: BTreeMap::new(),
}
}
}
@@ -53,9 +54,9 @@ pub enum HostKind {
#[derive(Debug, Default, Deserialize, Serialize, HasModel, TS)]
#[model = "Model<Self>"]
#[ts(export)]
pub struct HostInfo(BTreeMap<HostId, Host>);
pub struct Hosts(pub BTreeMap<HostId, Host>);
impl Map for HostInfo {
impl Map for Hosts {
type Key = HostId;
type Value = Host;
fn key_str(key: &Self::Key) -> Result<impl AsRef<str>, Error> {
@@ -75,7 +76,7 @@ pub fn host_for<'a>(
fn host_info<'a>(
db: &'a mut DatabaseModel,
package_id: &PackageId,
) -> Result<&'a mut Model<HostInfo>, Error> {
) -> Result<&'a mut Model<Hosts>, Error> {
Ok::<_, Error>(
db.as_public_mut()
.as_package_data_mut()
@@ -129,9 +130,3 @@ impl Model<Host> {
})
}
}
impl HostInfo {
pub fn get_host_primary(&self, host_id: &HostId) -> Option<HostAddress> {
self.0.get(&host_id).and_then(|h| h.primary.clone())
}
}

View File

@@ -4,7 +4,6 @@ use std::sync::{Arc, Weak};
use color_eyre::eyre::eyre;
use imbl::OrdMap;
use lazy_format::lazy_format;
use models::{HostId, OptionExt, PackageId};
use torut::onion::{OnionAddressV3, TorSecretKeyV3};
use tracing::instrument;
@@ -15,30 +14,27 @@ use crate::hostname::Hostname;
use crate::net::dns::DnsController;
use crate::net::forward::LanPortForwardController;
use crate::net::host::address::HostAddress;
use crate::net::host::binding::{AddSslOptions, BindOptions};
use crate::net::host::binding::{AddSslOptions, BindOptions, LanInfo};
use crate::net::host::{host_for, Host, HostKind};
use crate::net::service_interface::{HostnameInfo, IpHostname, OnionHostname};
use crate::net::tor::TorController;
use crate::net::vhost::{AlpnInfo, VHostController};
use crate::prelude::*;
use crate::util::serde::MaybeUtf8String;
use crate::HOST_IP;
pub struct NetController {
db: TypedPatchDb<Database>,
pub(super) tor: TorController,
pub(super) vhost: VHostController,
pub(super) dns: DnsController,
pub(super) forward: LanPortForwardController,
pub(super) os_bindings: Vec<Arc<()>>,
pub struct PreInitNetController {
pub db: TypedPatchDb<Database>,
tor: TorController,
vhost: VHostController,
os_bindings: Vec<Arc<()>>,
}
impl NetController {
impl PreInitNetController {
#[instrument(skip_all)]
pub async fn init(
db: TypedPatchDb<Database>,
tor_control: SocketAddr,
tor_socks: SocketAddr,
dns_bind: &[SocketAddr],
hostname: &Hostname,
os_tor_key: TorSecretKeyV3,
) -> Result<Self, Error> {
@@ -46,8 +42,6 @@ impl NetController {
db: db.clone(),
tor: TorController::new(tor_control, tor_socks),
vhost: VHostController::new(db),
dns: DnsController::init(dns_bind).await?,
forward: LanPortForwardController::new(),
os_bindings: Vec::new(),
};
res.add_os_bindings(hostname, os_tor_key).await?;
@@ -73,8 +67,6 @@ impl NetController {
alpn.clone(),
)
.await?;
self.os_bindings
.push(self.dns.add(None, HOST_IP.into()).await?);
// LAN IP
self.os_bindings.push(
@@ -142,6 +134,39 @@ impl NetController {
Ok(())
}
}
pub struct NetController {
db: TypedPatchDb<Database>,
pub(super) tor: TorController,
pub(super) vhost: VHostController,
pub(super) dns: DnsController,
pub(super) forward: LanPortForwardController,
pub(super) os_bindings: Vec<Arc<()>>,
}
impl NetController {
pub async fn init(
PreInitNetController {
db,
tor,
vhost,
os_bindings,
}: PreInitNetController,
dns_bind: &[SocketAddr],
) -> Result<Self, Error> {
let mut res = Self {
db,
tor,
vhost,
dns: DnsController::init(dns_bind).await?,
forward: LanPortForwardController::new(),
os_bindings,
};
res.os_bindings
.push(res.dns.add(None, HOST_IP.into()).await?);
Ok(res)
}
#[instrument(skip_all)]
pub async fn create_service(
@@ -164,7 +189,7 @@ impl NetController {
#[derive(Default, Debug)]
struct HostBinds {
lan: BTreeMap<u16, (u16, Option<AddSslOptions>, Arc<()>)>,
lan: BTreeMap<u16, (LanInfo, Option<AddSslOptions>, Vec<Arc<()>>)>,
tor: BTreeMap<OnionAddressV3, (OrdMap<u16, SocketAddr>, Vec<Arc<()>>)>,
}
@@ -209,105 +234,173 @@ impl NetService {
.await?;
self.update(id, host).await
}
pub async fn clear_bindings(&mut self) -> Result<(), Error> {
// TODO BLUJ
Ok(())
}
async fn update(&mut self, id: HostId, host: Host) -> Result<(), Error> {
dbg!(&host);
dbg!(&self.binds);
let ctrl = self.net_controller()?;
let mut hostname_info = BTreeMap::new();
let binds = {
if !self.binds.contains_key(&id) {
self.binds.insert(id.clone(), Default::default());
}
self.binds.get_mut(&id).unwrap()
};
if true
// TODO: if should listen lan
{
for (port, bind) in &host.bindings {
let old_lan_bind = binds.lan.remove(port);
let old_lan_port = old_lan_bind.as_ref().map(|(external, _, _)| *external);
let lan_bind = old_lan_bind.filter(|(external, ssl, _)| {
ssl == &bind.options.add_ssl
&& bind.assigned_lan_port.as_ref() == Some(external)
}); // only keep existing binding if relevant details match
if let Some(external) = bind.assigned_lan_port {
let new_lan_bind = if let Some(b) = lan_bind {
b
} else {
if let Some(ssl) = &bind.options.add_ssl {
let rc = ctrl
.vhost
let peek = ctrl.db.peek().await;
// LAN
let server_info = peek.as_public().as_server_info();
let ip_info = server_info.as_ip_info().de()?;
let hostname = server_info.as_hostname().de()?;
for (port, bind) in &host.bindings {
let old_lan_bind = binds.lan.remove(port);
let old_lan_port = old_lan_bind.as_ref().map(|(external, _, _)| *external);
let lan_bind = old_lan_bind
.filter(|(external, ssl, _)| ssl == &bind.options.add_ssl && bind.lan == *external); // only keep existing binding if relevant details match
if bind.lan.assigned_port.is_some() || bind.lan.assigned_ssl_port.is_some() {
let new_lan_bind = if let Some(b) = lan_bind {
b
} else {
let mut rcs = Vec::with_capacity(2);
if let Some(ssl) = &bind.options.add_ssl {
let external = bind
.lan
.assigned_ssl_port
.or_not_found("assigned ssl port")?;
rcs.push(
ctrl.vhost
.add(
None,
external,
(self.ip, *port).into(),
if bind.options.secure.as_ref().map_or(false, |s| s.ssl) {
Ok(())
if let Some(alpn) = ssl.alpn.clone() {
Err(alpn)
} else {
Err(ssl.alpn.clone())
if bind.options.secure.as_ref().map_or(false, |s| s.ssl) {
Ok(())
} else {
Err(AlpnInfo::Reflect)
}
},
)
.await?;
(*port, Some(ssl.clone()), rc)
.await?,
);
}
if let Some(security) = bind.options.secure {
if bind.options.add_ssl.is_some() && security.ssl {
// doesn't make sense to have 2 listening ports, both with ssl
} else {
let rc = ctrl.forward.add(external, (self.ip, *port).into()).await?;
(*port, None, rc)
let external =
bind.lan.assigned_port.or_not_found("assigned lan port")?;
rcs.push(ctrl.forward.add(external, (self.ip, *port).into()).await?);
}
};
binds.lan.insert(*port, new_lan_bind);
}
(bind.lan, bind.options.add_ssl.clone(), rcs)
};
let mut bind_hostname_info: Vec<HostnameInfo> =
hostname_info.remove(port).unwrap_or_default();
for (interface, ip_info) in &ip_info {
bind_hostname_info.push(HostnameInfo::Ip {
network_interface_id: interface.clone(),
public: false,
hostname: IpHostname::Local {
value: format!("{hostname}.local"),
port: new_lan_bind.0.assigned_port,
ssl_port: new_lan_bind.0.assigned_ssl_port,
},
});
if let Some(ipv4) = ip_info.ipv4 {
bind_hostname_info.push(HostnameInfo::Ip {
network_interface_id: interface.clone(),
public: false,
hostname: IpHostname::Ipv4 {
value: ipv4,
port: new_lan_bind.0.assigned_port,
ssl_port: new_lan_bind.0.assigned_ssl_port,
},
});
}
if let Some(ipv6) = ip_info.ipv6 {
bind_hostname_info.push(HostnameInfo::Ip {
network_interface_id: interface.clone(),
public: false,
hostname: IpHostname::Ipv6 {
value: ipv6,
port: new_lan_bind.0.assigned_port,
ssl_port: new_lan_bind.0.assigned_ssl_port,
},
});
}
}
if let Some(external) = old_lan_port {
hostname_info.insert(*port, bind_hostname_info);
binds.lan.insert(*port, new_lan_bind);
}
if let Some(lan) = old_lan_port {
if let Some(external) = lan.assigned_ssl_port {
ctrl.vhost.gc(None, external).await?;
}
if let Some(external) = lan.assigned_port {
ctrl.forward.gc(external).await?;
}
}
let mut removed = BTreeSet::new();
let mut removed_ssl = BTreeSet::new();
binds.lan.retain(|internal, (external, ssl, _)| {
if host.bindings.contains_key(internal) {
true
} else {
if ssl.is_some() {
removed_ssl.insert(*external);
} else {
removed.insert(*external);
}
false
}
});
for external in removed {
ctrl.forward.gc(external).await?;
}
let mut removed = BTreeSet::new();
binds.lan.retain(|internal, (external, _, _)| {
if host.bindings.contains_key(internal) {
true
} else {
removed.insert(*external);
false
}
for external in removed_ssl {
});
for lan in removed {
if let Some(external) = lan.assigned_ssl_port {
ctrl.vhost.gc(None, external).await?;
}
if let Some(external) = lan.assigned_port {
ctrl.forward.gc(external).await?;
}
}
let tor_binds: OrdMap<u16, SocketAddr> = host
.bindings
.iter()
.flat_map(|(internal, info)| {
let non_ssl = (
info.options.preferred_external_port,
SocketAddr::from((self.ip, *internal)),
struct TorHostnamePorts {
non_ssl: Option<u16>,
ssl: Option<u16>,
}
let mut tor_hostname_ports = BTreeMap::<u16, TorHostnamePorts>::new();
let mut tor_binds = OrdMap::<u16, SocketAddr>::new();
for (internal, info) in &host.bindings {
tor_binds.insert(
info.options.preferred_external_port,
SocketAddr::from((self.ip, *internal)),
);
if let (Some(ssl), Some(ssl_internal)) =
(&info.options.add_ssl, info.lan.assigned_ssl_port)
{
tor_binds.insert(
ssl.preferred_external_port,
SocketAddr::from(([127, 0, 0, 1], ssl_internal)),
);
if let (Some(ssl), Some(ssl_internal)) =
(&info.options.add_ssl, info.assigned_lan_port)
{
itertools::Either::Left(
[
(
ssl.preferred_external_port,
SocketAddr::from(([127, 0, 0, 1], ssl_internal)),
),
non_ssl,
]
.into_iter(),
)
} else {
itertools::Either::Right([non_ssl].into_iter())
}
})
.collect();
tor_hostname_ports.insert(
*internal,
TorHostnamePorts {
non_ssl: Some(info.options.preferred_external_port)
.filter(|p| *p != ssl.preferred_external_port),
ssl: Some(ssl.preferred_external_port),
},
);
} else {
tor_hostname_ports.insert(
*internal,
TorHostnamePorts {
non_ssl: Some(info.options.preferred_external_port),
ssl: None,
},
);
}
}
let mut keep_tor_addrs = BTreeSet::new();
for addr in match host.kind {
HostKind::Multi => {
@@ -324,13 +417,10 @@ impl NetService {
let new_tor_bind = if let Some(tor_bind) = tor_bind {
tor_bind
} else {
let key = ctrl
.db
.peek()
.await
.into_private()
.into_key_store()
.into_onion()
let key = peek
.as_private()
.as_key_store()
.as_onion()
.get_key(address)?;
let rcs = ctrl
.tor
@@ -338,6 +428,18 @@ impl NetService {
.await?;
(tor_binds.clone(), rcs)
};
for (internal, ports) in &tor_hostname_ports {
let mut bind_hostname_info =
hostname_info.remove(internal).unwrap_or_default();
bind_hostname_info.push(HostnameInfo::Onion {
hostname: OnionHostname {
value: address.to_string(),
port: ports.non_ssl,
ssl_port: ports.ssl,
},
});
hostname_info.insert(*internal, bind_hostname_info);
}
binds.tor.insert(address.clone(), new_tor_bind);
}
}
@@ -347,6 +449,14 @@ impl NetService {
ctrl.tor.gc(Some(addr.clone()), None).await?;
}
}
self.net_controller()?
.db
.mutate(|db| {
host_for(db, &self.id, &id, host.kind)?
.as_hostname_info_mut()
.ser(&hostname_info)
})
.await?;
Ok(())
}
@@ -355,12 +465,13 @@ impl NetService {
let mut errors = ErrorCollection::new();
if let Some(ctrl) = Weak::upgrade(&self.controller) {
for (_, binds) in std::mem::take(&mut self.binds) {
for (_, (external, ssl, rc)) in binds.lan {
for (_, (lan, _, rc)) in binds.lan {
drop(rc);
if ssl.is_some() {
errors.handle(ctrl.vhost.gc(None, external).await);
} else {
errors.handle(ctrl.forward.gc(external).await);
if let Some(external) = lan.assigned_ssl_port {
ctrl.vhost.gc(None, external).await?;
}
if let Some(external) = lan.assigned_port {
ctrl.forward.gc(external).await?;
}
}
for (addr, (_, rcs)) in binds.tor {
@@ -384,12 +495,12 @@ impl NetService {
self.ip
}
pub fn get_ext_port(&self, host_id: HostId, internal_port: u16) -> Result<u16, Error> {
pub fn get_ext_port(&self, host_id: HostId, internal_port: u16) -> Result<LanInfo, Error> {
let host_id_binds = self.binds.get_key_value(&host_id);
match host_id_binds {
Some((_, binds)) => {
if let Some(ext_port_info) = binds.lan.get(&internal_port) {
Ok(ext_port_info.0)
if let Some((lan, _, _)) = binds.lan.get(&internal_port) {
Ok(*lan)
} else {
Err(Error::new(
eyre!(

View File

@@ -0,0 +1,11 @@
<html>
<head>
<title>StartOS: Loading...</title>
<script>
setTimeout(window.location.reload, 1000)
</script>
</head>
<body>
Loading...
</body>
</html>

View File

@@ -1,51 +1,32 @@
use std::net::{Ipv4Addr, Ipv6Addr};
use imbl_value::InternedString;
use models::{HostId, ServiceInterfaceId};
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use crate::net::host::binding::BindOptions;
use crate::net::host::HostKind;
use crate::prelude::*;
#[derive(Clone, Debug, Deserialize, Serialize, TS)]
#[ts(export)]
#[serde(rename_all = "camelCase")]
pub struct ServiceInterfaceWithHostInfo {
#[serde(flatten)]
pub service_interface: ServiceInterface,
pub host_info: ExportedHostInfo,
}
#[derive(Clone, Debug, Deserialize, Serialize, TS)]
#[ts(export)]
#[serde(rename_all = "camelCase")]
pub struct ExportedHostInfo {
pub id: HostId,
pub kind: HostKind,
pub hostnames: Vec<ExportedHostnameInfo>,
}
use crate::net::host::address::HostAddress;
#[derive(Clone, Debug, Deserialize, Serialize, TS)]
#[ts(export)]
#[serde(rename_all = "camelCase")]
#[serde(rename_all_fields = "camelCase")]
#[serde(tag = "kind")]
pub enum ExportedHostnameInfo {
pub enum HostnameInfo {
Ip {
network_interface_id: String,
public: bool,
hostname: ExportedIpHostname,
hostname: IpHostname,
},
Onion {
hostname: ExportedOnionHostname,
hostname: OnionHostname,
},
}
#[derive(Clone, Debug, Deserialize, Serialize, TS)]
#[ts(export)]
#[serde(rename_all = "camelCase")]
pub struct ExportedOnionHostname {
pub struct OnionHostname {
pub value: String,
pub port: Option<u16>,
pub ssl_port: Option<u16>,
@@ -56,7 +37,7 @@ pub struct ExportedOnionHostname {
#[serde(rename_all = "camelCase")]
#[serde(rename_all_fields = "camelCase")]
#[serde(tag = "kind")]
pub enum ExportedIpHostname {
pub enum IpHostname {
Ipv4 {
value: Ipv4Addr,
port: Option<u16>,
@@ -110,6 +91,10 @@ pub enum ServiceInterfaceType {
pub struct AddressInfo {
pub username: Option<String>,
pub host_id: HostId,
pub bind_options: BindOptions,
pub internal_port: u16,
#[ts(type = "string | null")]
pub scheme: Option<InternedString>,
#[ts(type = "string | null")]
pub ssl_scheme: Option<InternedString>,
pub suffix: String,
}

View File

@@ -1,4 +1,3 @@
use std::fs::Metadata;
use std::future::Future;
use std::path::{Path, PathBuf};
use std::time::UNIX_EPOCH;
@@ -13,25 +12,26 @@ use digest::Digest;
use futures::future::ready;
use http::header::ACCEPT_ENCODING;
use http::request::Parts as RequestParts;
use http::{HeaderMap, Method, StatusCode};
use http::{Method, StatusCode};
use imbl_value::InternedString;
use include_dir::Dir;
use new_mime_guess::MimeGuess;
use openssl::hash::MessageDigest;
use openssl::x509::X509;
use rpc_toolkit::Server;
use rpc_toolkit::{Context, HttpServer, Server};
use tokio::fs::File;
use tokio::io::BufReader;
use tokio_util::io::ReaderStream;
use crate::context::{DiagnosticContext, InstallContext, RpcContext, SetupContext};
use crate::db::subscribe;
use crate::context::{DiagnosticContext, InitContext, InstallContext, RpcContext, SetupContext};
use crate::hostname::Hostname;
use crate::middleware::auth::{Auth, HasValidSession};
use crate::middleware::cors::Cors;
use crate::middleware::db::SyncDb;
use crate::middleware::diagnostic::DiagnosticMode;
use crate::rpc_continuations::Guid;
use crate::{diagnostic_api, install_api, main_api, setup_api, Error, ErrorKind, ResultExt};
use crate::rpc_continuations::{Guid, RpcContinuations};
use crate::{
diagnostic_api, init_api, install_api, main_api, setup_api, Error, ErrorKind, ResultExt,
};
const NOT_FOUND: &[u8] = b"Not Found";
const METHOD_NOT_ALLOWED: &[u8] = b"Method Not Allowed";
@@ -49,7 +49,6 @@ const PROXY_STRIP_HEADERS: &[&str] = &["cookie", "host", "origin", "referer", "u
#[derive(Clone)]
pub enum UiMode {
Setup,
Diag,
Install,
Main,
}
@@ -58,128 +57,46 @@ impl UiMode {
fn path(&self, path: &str) -> PathBuf {
match self {
Self::Setup => Path::new("setup-wizard").join(path),
Self::Diag => Path::new("diagnostic-ui").join(path),
Self::Install => Path::new("install-wizard").join(path),
Self::Main => Path::new("ui").join(path),
}
}
}
pub fn setup_ui_file_router(ctx: SetupContext) -> Router {
Router::new()
.route_service(
"/rpc/*path",
post(Server::new(move || ready(Ok(ctx.clone())), setup_api()).middleware(Cors::new())),
)
.fallback(any(|request: Request| async move {
alt_ui(request, UiMode::Setup)
.await
.unwrap_or_else(server_error)
}))
}
pub fn diag_ui_file_router(ctx: DiagnosticContext) -> Router {
pub fn rpc_router<C: Context + Clone + AsRef<RpcContinuations>>(
ctx: C,
server: HttpServer<C>,
) -> Router {
Router::new()
.route("/rpc/*path", post(server))
.route(
"/rpc/*path",
post(
Server::new(move || ready(Ok(ctx.clone())), diagnostic_api())
.middleware(Cors::new())
.middleware(DiagnosticMode::new()),
),
)
.fallback(any(|request: Request| async move {
alt_ui(request, UiMode::Diag)
.await
.unwrap_or_else(server_error)
}))
}
pub fn install_ui_file_router(ctx: InstallContext) -> Router {
Router::new()
.route("/rpc/*path", {
let ctx = ctx.clone();
post(Server::new(move || ready(Ok(ctx.clone())), install_api()).middleware(Cors::new()))
})
.fallback(any(|request: Request| async move {
alt_ui(request, UiMode::Install)
.await
.unwrap_or_else(server_error)
}))
}
pub fn main_ui_server_router(ctx: RpcContext) -> Router {
Router::new()
.route("/rpc/*path", {
let ctx = ctx.clone();
post(
Server::new(move || ready(Ok(ctx.clone())), main_api::<RpcContext>())
.middleware(Cors::new())
.middleware(Auth::new())
.middleware(SyncDb::new()),
)
})
.route(
"/ws/db",
any({
let ctx = ctx.clone();
move |headers: HeaderMap, ws: x::WebSocketUpgrade| async move {
subscribe(ctx, headers, ws)
.await
.unwrap_or_else(server_error)
}
}),
)
.route(
"/ws/rpc/*path",
"/ws/rpc/:guid",
get({
let ctx = ctx.clone();
move |x::Path(path): x::Path<String>,
move |x::Path(guid): x::Path<Guid>,
ws: axum::extract::ws::WebSocketUpgrade| async move {
match Guid::from(&path) {
None => {
tracing::debug!("No Guid Path");
bad_request()
}
Some(guid) => match ctx.rpc_continuations.get_ws_handler(&guid).await {
Some(cont) => ws.on_upgrade(cont),
_ => not_found(),
},
match AsRef::<RpcContinuations>::as_ref(&ctx).get_ws_handler(&guid).await {
Some(cont) => ws.on_upgrade(cont),
_ => not_found(),
}
}
}),
)
.route(
"/rest/rpc/*path",
"/rest/rpc/:guid",
any({
let ctx = ctx.clone();
move |request: x::Request| async move {
let path = request
.uri()
.path()
.strip_prefix("/rest/rpc/")
.unwrap_or_default();
match Guid::from(&path) {
None => {
tracing::debug!("No Guid Path");
bad_request()
}
Some(guid) => match ctx.rpc_continuations.get_rest_handler(&guid).await {
None => not_found(),
Some(cont) => cont(request).await.unwrap_or_else(server_error),
},
move |x::Path(guid): x::Path<Guid>, request: x::Request| async move {
match AsRef::<RpcContinuations>::as_ref(&ctx).get_rest_handler(&guid).await {
None => not_found(),
Some(cont) => cont(request).await.unwrap_or_else(server_error),
}
}
}),
)
.fallback(any(move |request: Request| async move {
main_start_os_ui(request, ctx)
.await
.unwrap_or_else(server_error)
}))
}
async fn alt_ui(req: Request, ui_mode: UiMode) -> Result<Response, Error> {
fn serve_ui(req: Request, ui_mode: UiMode) -> Result<Response, Error> {
let (request_parts, _body) = req.into_parts();
match &request_parts.method {
&Method::GET => {
@@ -196,9 +113,7 @@ async fn alt_ui(req: Request, ui_mode: UiMode) -> Result<Response, Error> {
.or_else(|| EMBEDDED_UIS.get_file(&*ui_mode.path("index.html")));
if let Some(file) = file {
FileData::from_embedded(&request_parts, file)
.into_response(&request_parts)
.await
FileData::from_embedded(&request_parts, file).into_response(&request_parts)
} else {
Ok(not_found())
}
@@ -207,6 +122,75 @@ async fn alt_ui(req: Request, ui_mode: UiMode) -> Result<Response, Error> {
}
}
pub fn setup_ui_router(ctx: SetupContext) -> Router {
rpc_router(
ctx.clone(),
Server::new(move || ready(Ok(ctx.clone())), setup_api()).middleware(Cors::new()),
)
.fallback(any(|request: Request| async move {
serve_ui(request, UiMode::Setup).unwrap_or_else(server_error)
}))
}
pub fn diagnostic_ui_router(ctx: DiagnosticContext) -> Router {
rpc_router(
ctx.clone(),
Server::new(move || ready(Ok(ctx.clone())), diagnostic_api()).middleware(Cors::new()),
)
.fallback(any(|request: Request| async move {
serve_ui(request, UiMode::Main).unwrap_or_else(server_error)
}))
}
pub fn install_ui_router(ctx: InstallContext) -> Router {
rpc_router(
ctx.clone(),
Server::new(move || ready(Ok(ctx.clone())), install_api()).middleware(Cors::new()),
)
.fallback(any(|request: Request| async move {
serve_ui(request, UiMode::Install).unwrap_or_else(server_error)
}))
}
pub fn init_ui_router(ctx: InitContext) -> Router {
rpc_router(
ctx.clone(),
Server::new(move || ready(Ok(ctx.clone())), init_api()).middleware(Cors::new()),
)
.fallback(any(|request: Request| async move {
serve_ui(request, UiMode::Main).unwrap_or_else(server_error)
}))
}
pub fn main_ui_router(ctx: RpcContext) -> Router {
rpc_router(
ctx.clone(),
Server::new(move || ready(Ok(ctx.clone())), main_api::<RpcContext>())
.middleware(Cors::new())
.middleware(Auth::new())
.middleware(SyncDb::new()),
)
// TODO: cert
.fallback(any(|request: Request| async move {
serve_ui(request, UiMode::Main).unwrap_or_else(server_error)
}))
}
pub fn refresher() -> Router {
Router::new().fallback(get(|request: Request| async move {
let res = include_bytes!("./refresher.html");
FileData {
data: Body::from(&res[..]),
e_tag: None,
encoding: None,
len: Some(res.len() as u64),
mime: Some("text/html".into()),
}
.into_response(&request.into_parts().0)
.unwrap_or_else(server_error)
}))
}
async fn if_authorized<
F: FnOnce() -> Fut,
Fut: Future<Output = Result<Response, Error>> + Send + Sync,
@@ -223,89 +207,6 @@ async fn if_authorized<
}
}
async fn main_start_os_ui(req: Request, ctx: RpcContext) -> Result<Response, Error> {
let (request_parts, _body) = req.into_parts();
match (
&request_parts.method,
request_parts
.uri
.path()
.strip_prefix('/')
.unwrap_or(request_parts.uri.path())
.split_once('/'),
) {
(&Method::GET, Some(("public", path))) => {
todo!("pull directly from s9pk")
}
(&Method::GET, Some(("proxy", target))) => {
if_authorized(&ctx, &request_parts, || async {
let target = urlencoding::decode(target)?;
let res = ctx
.client
.get(target.as_ref())
.headers(
request_parts
.headers
.iter()
.filter(|(h, _)| {
!PROXY_STRIP_HEADERS
.iter()
.any(|bad| h.as_str().eq_ignore_ascii_case(bad))
})
.flat_map(|(h, v)| {
Some((
reqwest::header::HeaderName::from_lowercase(
h.as_str().as_bytes(),
)
.ok()?,
reqwest::header::HeaderValue::from_bytes(v.as_bytes()).ok()?,
))
})
.collect(),
)
.send()
.await
.with_kind(crate::ErrorKind::Network)?;
let mut hres = Response::builder().status(res.status().as_u16());
for (h, v) in res.headers().clone() {
if let Some(h) = h {
hres = hres.header(h.to_string(), v.as_bytes());
}
}
hres.body(Body::from_stream(res.bytes_stream()))
.with_kind(crate::ErrorKind::Network)
})
.await
}
(&Method::GET, Some(("eos", "local.crt"))) => {
let account = ctx.account.read().await;
cert_send(&account.root_ca_cert, &account.hostname)
}
(&Method::GET, _) => {
let uri_path = UiMode::Main.path(
request_parts
.uri
.path()
.strip_prefix('/')
.unwrap_or(request_parts.uri.path()),
);
let file = EMBEDDED_UIS
.get_file(&*uri_path)
.or_else(|| EMBEDDED_UIS.get_file(&*UiMode::Main.path("index.html")));
if let Some(file) = file {
FileData::from_embedded(&request_parts, file)
.into_response(&request_parts)
.await
} else {
Ok(not_found())
}
}
_ => Ok(method_not_allowed()),
}
}
pub fn unauthorized(err: Error, path: &str) -> Response {
tracing::warn!("unauthorized for {} @{:?}", err, path);
tracing::debug!("{:?}", err);
@@ -373,8 +274,8 @@ struct FileData {
data: Body,
len: Option<u64>,
encoding: Option<&'static str>,
e_tag: String,
mime: Option<String>,
e_tag: Option<String>,
mime: Option<InternedString>,
}
impl FileData {
fn from_embedded(req: &RequestParts, file: &'static include_dir::File<'static>) -> Self {
@@ -407,10 +308,23 @@ impl FileData {
len: Some(data.len() as u64),
encoding,
data: data.into(),
e_tag: e_tag(path, None),
e_tag: file.metadata().map(|metadata| {
e_tag(
path,
format!(
"{}",
metadata
.modified()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs() as i64)
.unwrap_or_else(|e| e.duration().as_secs() as i64 * -1),
)
.as_bytes(),
)
}),
mime: MimeGuess::from_path(path)
.first()
.map(|m| m.essence_str().to_owned()),
.map(|m| m.essence_str().into()),
}
}
@@ -434,7 +348,18 @@ impl FileData {
.await
.with_ctx(|_| (ErrorKind::Filesystem, path.display().to_string()))?;
let e_tag = e_tag(path, Some(&metadata));
let e_tag = Some(e_tag(
path,
format!(
"{}",
metadata
.modified()?
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs() as i64)
.unwrap_or_else(|e| e.duration().as_secs() as i64 * -1)
)
.as_bytes(),
));
let (len, data) = if encoding == Some("gzip") {
(
@@ -455,16 +380,18 @@ impl FileData {
e_tag,
mime: MimeGuess::from_path(path)
.first()
.map(|m| m.essence_str().to_owned()),
.map(|m| m.essence_str().into()),
})
}
async fn into_response(self, req: &RequestParts) -> Result<Response, Error> {
fn into_response(self, req: &RequestParts) -> Result<Response, Error> {
let mut builder = Response::builder();
if let Some(mime) = self.mime {
builder = builder.header(http::header::CONTENT_TYPE, &*mime);
}
builder = builder.header(http::header::ETAG, &*self.e_tag);
if let Some(e_tag) = &self.e_tag {
builder = builder.header(http::header::ETAG, &**e_tag);
}
builder = builder.header(
http::header::CACHE_CONTROL,
"public, max-age=21000000, immutable",
@@ -481,11 +408,12 @@ impl FileData {
builder = builder.header(http::header::CONNECTION, "keep-alive");
}
if req
.headers
.get("if-none-match")
.and_then(|h| h.to_str().ok())
== Some(self.e_tag.as_ref())
if self.e_tag.is_some()
&& req
.headers
.get("if-none-match")
.and_then(|h| h.to_str().ok())
== self.e_tag.as_deref()
{
builder = builder.status(StatusCode::NOT_MODIFIED);
builder.body(Body::empty())
@@ -503,21 +431,14 @@ impl FileData {
}
}
fn e_tag(path: &Path, metadata: Option<&Metadata>) -> String {
lazy_static::lazy_static! {
static ref INSTANCE_NONCE: u64 = rand::random();
}
fn e_tag(path: &Path, modified: impl AsRef<[u8]>) -> String {
let mut hasher = sha2::Sha256::new();
hasher.update(format!("{:?}", path).as_bytes());
if let Some(modified) = metadata.and_then(|m| m.modified().ok()) {
hasher.update(
format!(
"{}",
modified
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
)
.as_bytes(),
);
}
hasher.update(modified.as_ref());
let res = hasher.finalize();
format!(
"\"{}\"",

View File

@@ -112,24 +112,6 @@ pub async fn find_eth_iface() -> Result<String, Error> {
))
}
#[pin_project::pin_project]
pub struct SingleAccept<T>(Option<T>);
impl<T> SingleAccept<T> {
pub fn new(conn: T) -> Self {
Self(Some(conn))
}
}
// impl<T> axum_server::accept::Accept for SingleAccept<T> {
// type Conn = T;
// type Error = Infallible;
// fn poll_accept(
// self: std::pin::Pin<&mut Self>,
// _cx: &mut std::task::Context<'_>,
// ) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
// std::task::Poll::Ready(self.project().0.take().map(Ok))
// }
// }
pub struct TcpListeners {
listeners: Vec<TcpListener>,
}

View File

@@ -1,10 +1,15 @@
use std::collections::BTreeMap;
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::str::FromStr;
use std::sync::{Arc, Weak};
use std::time::Duration;
use axum::body::Body;
use axum::extract::Request;
use axum::response::Response;
use color_eyre::eyre::eyre;
use helpers::NonDetachingJoinHandle;
use http::Uri;
use imbl_value::InternedString;
use models::ResultExt;
use serde::{Deserialize, Serialize};
@@ -20,8 +25,9 @@ use tracing::instrument;
use ts_rs::TS;
use crate::db::model::Database;
use crate::net::static_server::server_error;
use crate::prelude::*;
use crate::util::io::{BackTrackingReader, TimeoutStream};
use crate::util::io::BackTrackingReader;
use crate::util::serde::MaybeUtf8String;
// not allowed: <=1024, >=32768, 5355, 5432, 9050, 6010, 9051, 5353
@@ -113,8 +119,16 @@ impl VHostServer {
loop {
match listener.accept().await {
Ok((stream, _)) => {
let stream =
Box::pin(TimeoutStream::new(stream, Duration::from_secs(300)));
if let Err(e) = socket2::SockRef::from(&stream).set_tcp_keepalive(
&socket2::TcpKeepalive::new()
.with_time(Duration::from_secs(900))
.with_interval(Duration::from_secs(60))
.with_retries(5),
) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
let mut stream = BackTrackingReader::new(stream);
stream.start_buffering();
let mapping = mapping.clone();
@@ -129,38 +143,39 @@ impl VHostServer {
{
Ok(a) => a,
Err(_) => {
// stream.rewind();
// return hyper::server::Server::builder(
// SingleAccept::new(stream),
// )
// .serve(make_service_fn(|_| async {
// Ok::<_, Infallible>(service_fn(|req| async move {
// let host = req
// .headers()
// .get(http::header::HOST)
// .and_then(|host| host.to_str().ok());
// let uri = Uri::from_parts({
// let mut parts =
// req.uri().to_owned().into_parts();
// parts.authority = host
// .map(FromStr::from_str)
// .transpose()?;
// parts
// })?;
// Response::builder()
// .status(
// http::StatusCode::TEMPORARY_REDIRECT,
// )
// .header(
// http::header::LOCATION,
// uri.to_string(),
// )
// .body(Body::default())
// }))
// }))
// .await
// .with_kind(crate::ErrorKind::Network);
todo!()
stream.rewind();
return hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
.serve_connection(
hyper_util::rt::TokioIo::new(stream),
hyper_util::service::TowerToHyperService::new(axum::Router::new().fallback(
axum::routing::method_routing::any(move |req: Request| async move {
match async move {
let host = req
.headers()
.get(http::header::HOST)
.and_then(|host| host.to_str().ok());
let uri = Uri::from_parts({
let mut parts = req.uri().to_owned().into_parts();
parts.authority = host.map(FromStr::from_str).transpose()?;
parts
})?;
Response::builder()
.status(http::StatusCode::TEMPORARY_REDIRECT)
.header(http::header::LOCATION, uri.to_string())
.body(Body::default())
}.await {
Ok(a) => a,
Err(e) => {
tracing::warn!("Error redirecting http request on ssl port: {e}");
tracing::error!("{e:?}");
server_error(Error::new(e, ErrorKind::Network))
}
}
}),
)),
)
.await
.map_err(|e| Error::new(color_eyre::eyre::Report::msg(e), ErrorKind::Network));
}
};
let target_name =
@@ -205,7 +220,7 @@ impl VHostServer {
.into_entries()?
.into_iter()
.flat_map(|(_, ips)| [
ips.as_ipv4().de().map(|ip| ip.map(IpAddr::V4)),
ips.as_ipv4().de().map(|ip| ip.map(IpAddr::V4)),
ips.as_ipv6().de().map(|ip| ip.map(IpAddr::V6))
])
.filter_map(|a| a.transpose())

View File

@@ -1,23 +1,84 @@
use std::convert::Infallible;
use std::net::SocketAddr;
use std::task::Poll;
use std::time::Duration;
use axum::extract::Request;
use axum::Router;
use axum_server::Handle;
use bytes::Bytes;
use futures::future::ready;
use futures::FutureExt;
use helpers::NonDetachingJoinHandle;
use tokio::sync::oneshot;
use tokio::sync::{oneshot, watch};
use crate::context::{DiagnosticContext, InstallContext, RpcContext, SetupContext};
use crate::context::{DiagnosticContext, InitContext, InstallContext, RpcContext, SetupContext};
use crate::net::static_server::{
diag_ui_file_router, install_ui_file_router, main_ui_server_router, setup_ui_file_router,
diagnostic_ui_router, init_ui_router, install_ui_router, main_ui_router, refresher,
setup_ui_router,
};
use crate::Error;
use crate::prelude::*;
#[derive(Clone)]
pub struct SwappableRouter(watch::Sender<Router>);
impl SwappableRouter {
pub fn new(router: Router) -> Self {
Self(watch::channel(router).0)
}
pub fn swap(&self, router: Router) {
let _ = self.0.send_replace(router);
}
}
#[derive(Clone)]
pub struct SwappableRouterService(watch::Receiver<Router>);
impl<B> tower_service::Service<Request<B>> for SwappableRouterService
where
B: axum::body::HttpBody<Data = Bytes> + Send + 'static,
B::Error: Into<axum::BoxError>,
{
type Response = <Router as tower_service::Service<Request<B>>>::Response;
type Error = <Router as tower_service::Service<Request<B>>>::Error;
type Future = <Router as tower_service::Service<Request<B>>>::Future;
#[inline]
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
let mut changed = self.0.changed().boxed();
if changed.poll_unpin(cx).is_ready() {
return Poll::Ready(Ok(()));
}
drop(changed);
tower_service::Service::<Request<B>>::poll_ready(&mut self.0.borrow().clone(), cx)
}
fn call(&mut self, req: Request<B>) -> Self::Future {
self.0.borrow().clone().call(req)
}
}
impl<T> tower_service::Service<T> for SwappableRouter {
type Response = SwappableRouterService;
type Error = Infallible;
type Future = futures::future::Ready<Result<Self::Response, Self::Error>>;
#[inline]
fn poll_ready(
&mut self,
_: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _: T) -> Self::Future {
ready(Ok(SwappableRouterService(self.0.subscribe())))
}
}
pub struct WebServer {
shutdown: oneshot::Sender<()>,
router: SwappableRouter,
thread: NonDetachingJoinHandle<()>,
}
impl WebServer {
pub fn new(bind: SocketAddr, router: Router) -> Self {
pub fn new(bind: SocketAddr) -> Self {
let router = SwappableRouter::new(refresher());
let thread_router = router.clone();
let (shutdown, shutdown_recv) = oneshot::channel();
let thread = NonDetachingJoinHandle::from(tokio::spawn(async move {
let handle = Handle::new();
@@ -25,14 +86,18 @@ impl WebServer {
server.http_builder().http1().preserve_header_case(true);
server.http_builder().http1().title_case_headers(true);
if let (Err(e), _) = tokio::join!(server.serve(router.into_make_service()), async {
if let (Err(e), _) = tokio::join!(server.serve(thread_router), async {
let _ = shutdown_recv.await;
handle.graceful_shutdown(Some(Duration::from_secs(0)));
}) {
tracing::error!("Spawning hyper server error: {}", e);
}
}));
Self { shutdown, thread }
Self {
shutdown,
router,
thread,
}
}
pub async fn shutdown(self) {
@@ -40,19 +105,27 @@ impl WebServer {
self.thread.await.unwrap()
}
pub fn main(bind: SocketAddr, ctx: RpcContext) -> Result<Self, Error> {
Ok(Self::new(bind, main_ui_server_router(ctx)))
pub fn serve_router(&mut self, router: Router) {
self.router.swap(router)
}
pub fn setup(bind: SocketAddr, ctx: SetupContext) -> Result<Self, Error> {
Ok(Self::new(bind, setup_ui_file_router(ctx)))
pub fn serve_main(&mut self, ctx: RpcContext) {
self.serve_router(main_ui_router(ctx))
}
pub fn diagnostic(bind: SocketAddr, ctx: DiagnosticContext) -> Result<Self, Error> {
Ok(Self::new(bind, diag_ui_file_router(ctx)))
pub fn serve_setup(&mut self, ctx: SetupContext) {
self.serve_router(setup_ui_router(ctx))
}
pub fn install(bind: SocketAddr, ctx: InstallContext) -> Result<Self, Error> {
Ok(Self::new(bind, install_ui_file_router(ctx)))
pub fn serve_diagnostic(&mut self, ctx: DiagnosticContext) {
self.serve_router(diagnostic_ui_router(ctx))
}
pub fn serve_install(&mut self, ctx: InstallContext) {
self.serve_router(install_ui_router(ctx))
}
pub fn serve_init(&mut self, ctx: InitContext) {
self.serve_router(init_ui_router(ctx))
}
}

View File

@@ -87,7 +87,7 @@ pub async fn partition(disk: &DiskInfo, overwrite: bool) -> Result<OsPartitionIn
gpt.add_partition(
"root",
15 * 1024 * 1024 * 1024,
match *crate::ARCH {
match crate::ARCH {
"x86_64" => gpt::partition_types::LINUX_ROOT_X64,
"aarch64" => gpt::partition_types::LINUX_ROOT_ARM_64,
_ => gpt::partition_types::LINUX_FS,

View File

@@ -366,7 +366,7 @@ pub async fn execute<C: Context>(
if tokio::fs::metadata("/sys/firmware/efi").await.is_err() {
install.arg("--target=i386-pc");
} else {
match *ARCH {
match ARCH {
"x86_64" => install.arg("--target=x86_64-efi"),
"aarch64" => install.arg("--target=arm64-efi"),
_ => &mut install,

View File

@@ -1,14 +1,16 @@
use std::panic::UnwindSafe;
use std::sync::Arc;
use std::time::Duration;
use futures::Future;
use futures::future::pending;
use futures::stream::BoxStream;
use futures::{Future, FutureExt, StreamExt, TryFutureExt};
use helpers::NonDetachingJoinHandle;
use imbl_value::{InOMap, InternedString};
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncSeek, AsyncWrite};
use tokio::sync::{mpsc, watch};
use tokio::sync::watch;
use ts_rs::TS;
use crate::db::model::{Database, DatabaseModel};
@@ -168,39 +170,23 @@ impl FullProgress {
}
}
#[derive(Clone)]
pub struct FullProgressTracker {
overall: Arc<watch::Sender<Progress>>,
overall_recv: watch::Receiver<Progress>,
phases: InOMap<InternedString, watch::Receiver<Progress>>,
new_phase: (
mpsc::UnboundedSender<(InternedString, watch::Receiver<Progress>)>,
mpsc::UnboundedReceiver<(InternedString, watch::Receiver<Progress>)>,
),
overall: watch::Sender<Progress>,
phases: watch::Sender<InOMap<InternedString, watch::Receiver<Progress>>>,
}
impl FullProgressTracker {
pub fn new() -> Self {
let (overall, overall_recv) = watch::channel(Progress::new());
Self {
overall: Arc::new(overall),
overall_recv,
phases: InOMap::new(),
new_phase: mpsc::unbounded_channel(),
}
let (overall, _) = watch::channel(Progress::new());
let (phases, _) = watch::channel(InOMap::new());
Self { overall, phases }
}
fn fill_phases(&mut self) -> bool {
let mut changed = false;
while let Ok((name, phase)) = self.new_phase.1.try_recv() {
self.phases.insert(name, phase);
changed = true;
}
changed
}
pub fn snapshot(&mut self) -> FullProgress {
self.fill_phases();
pub fn snapshot(&self) -> FullProgress {
FullProgress {
overall: *self.overall.borrow(),
phases: self
.phases
.borrow()
.iter()
.map(|(name, progress)| NamedProgress {
name: name.clone(),
@@ -209,28 +195,75 @@ impl FullProgressTracker {
.collect(),
}
}
pub async fn changed(&mut self) {
if self.fill_phases() {
return;
}
let phases = self
.phases
.iter_mut()
.map(|(_, p)| Box::pin(p.changed()))
.collect_vec();
tokio::select! {
_ = self.overall_recv.changed() => (),
_ = futures::future::select_all(phases) => (),
}
}
pub fn handle(&self) -> FullProgressTrackerHandle {
FullProgressTrackerHandle {
overall: self.overall.clone(),
new_phase: self.new_phase.0.clone(),
pub fn stream(&self, min_interval: Option<Duration>) -> BoxStream<'static, FullProgress> {
struct StreamState {
overall: watch::Receiver<Progress>,
phases_recv: watch::Receiver<InOMap<InternedString, watch::Receiver<Progress>>>,
phases: InOMap<InternedString, watch::Receiver<Progress>>,
}
let mut overall = self.overall.subscribe();
overall.mark_changed(); // make sure stream starts with a value
let phases_recv = self.phases.subscribe();
let phases = phases_recv.borrow().clone();
let state = StreamState {
overall,
phases_recv,
phases,
};
futures::stream::unfold(
state,
move |StreamState {
mut overall,
mut phases_recv,
mut phases,
}| async move {
let changed = phases
.iter_mut()
.map(|(_, p)| async move { p.changed().or_else(|_| pending()).await }.boxed())
.chain([overall.changed().boxed()])
.chain([phases_recv.changed().boxed()])
.map(|fut| fut.map(|r| r.unwrap_or_default()))
.collect_vec();
if let Some(min_interval) = min_interval {
tokio::join!(
tokio::time::sleep(min_interval),
futures::future::select_all(changed),
);
} else {
futures::future::select_all(changed).await;
}
for (name, phase) in &*phases_recv.borrow_and_update() {
if !phases.contains_key(name) {
phases.insert(name.clone(), phase.clone());
}
}
let o = *overall.borrow_and_update();
Some((
FullProgress {
overall: o,
phases: phases
.iter_mut()
.map(|(name, progress)| NamedProgress {
name: name.clone(),
progress: *progress.borrow_and_update(),
})
.collect(),
},
StreamState {
overall,
phases_recv,
phases,
},
))
},
)
.boxed()
}
pub fn sync_to_db<DerefFn>(
mut self,
&self,
db: TypedPatchDb<Database>,
deref: DerefFn,
min_interval: Option<Duration>,
@@ -239,9 +272,9 @@ impl FullProgressTracker {
DerefFn: Fn(&mut DatabaseModel) -> Option<&mut Model<FullProgress>> + 'static,
for<'a> &'a DerefFn: UnwindSafe + Send,
{
let mut stream = self.stream(min_interval);
async move {
loop {
let progress = self.snapshot();
while let Some(progress) = stream.next().await {
if db
.mutate(|v| {
if let Some(p) = deref(v) {
@@ -255,25 +288,23 @@ impl FullProgressTracker {
{
break;
}
tokio::join!(self.changed(), async {
if let Some(interval) = min_interval {
tokio::time::sleep(interval).await
} else {
futures::future::ready(()).await
}
});
}
Ok(())
}
}
}
#[derive(Clone)]
pub struct FullProgressTrackerHandle {
overall: Arc<watch::Sender<Progress>>,
new_phase: mpsc::UnboundedSender<(InternedString, watch::Receiver<Progress>)>,
}
impl FullProgressTrackerHandle {
pub fn progress_bar_task(&self, name: &str) -> NonDetachingJoinHandle<()> {
let mut stream = self.stream(None);
let mut bar = PhasedProgressBar::new(name);
tokio::spawn(async move {
while let Some(progress) = stream.next().await {
bar.update(&progress);
if progress.overall.is_complete() {
break;
}
}
})
.into()
}
pub fn add_phase(
&self,
name: InternedString,
@@ -284,7 +315,9 @@ impl FullProgressTrackerHandle {
.send_modify(|o| o.add_total(overall_contribution));
}
let (send, recv) = watch::channel(Progress::new());
let _ = self.new_phase.send((name, recv));
self.phases.send_modify(|p| {
p.insert(name, recv);
});
PhaseProgressTrackerHandle {
overall: self.overall.clone(),
overall_contribution,
@@ -298,7 +331,7 @@ impl FullProgressTrackerHandle {
}
pub struct PhaseProgressTrackerHandle {
overall: Arc<watch::Sender<Progress>>,
overall: watch::Sender<Progress>,
overall_contribution: Option<u64>,
contributed: u64,
progress: watch::Sender<Progress>,

View File

@@ -138,7 +138,6 @@ impl Middleware<RegistryContext> for Auth {
if request.headers().contains_key(AUTH_SIG_HEADER) {
self.signer = Some(
async {
let request = request;
let SignatureHeader {
commitment,
signer,

View File

@@ -134,7 +134,7 @@ pub struct HardwareInfo {
impl From<&RpcContext> for HardwareInfo {
fn from(value: &RpcContext) -> Self {
Self {
arch: InternedString::intern(&**crate::ARCH),
arch: InternedString::intern(crate::ARCH),
ram: value.hardware.ram,
devices: value
.hardware

View File

@@ -70,7 +70,7 @@ pub fn registry_api<C: Context>() -> ParentHandler<C> {
.subcommand("db", db::db_api::<C>())
}
pub fn registry_server_router(ctx: RegistryContext) -> Router {
pub fn registry_router(ctx: RegistryContext) -> Router {
use axum::extract as x;
use axum::routing::{any, get, post};
Router::new()
@@ -128,7 +128,7 @@ pub fn registry_server_router(ctx: RegistryContext) -> Router {
}
impl WebServer {
pub fn registry(bind: SocketAddr, ctx: RegistryContext) -> Self {
Self::new(bind, registry_server_router(ctx))
pub fn serve_registry(&mut self, ctx: RegistryContext) {
self.serve_router(registry_router(ctx))
}
}

View File

@@ -186,29 +186,16 @@ pub async fn cli_add_asset(
let file = MultiCursorFile::from(tokio::fs::File::open(&path).await?);
let mut progress = FullProgressTracker::new();
let progress_handle = progress.handle();
let mut sign_phase =
progress_handle.add_phase(InternedString::intern("Signing File"), Some(10));
let mut verify_phase =
progress_handle.add_phase(InternedString::intern("Verifying URL"), Some(100));
let mut index_phase = progress_handle.add_phase(
let progress = FullProgressTracker::new();
let mut sign_phase = progress.add_phase(InternedString::intern("Signing File"), Some(10));
let mut verify_phase = progress.add_phase(InternedString::intern("Verifying URL"), Some(100));
let mut index_phase = progress.add_phase(
InternedString::intern("Adding File to Registry Index"),
Some(1),
);
let progress_task: NonDetachingJoinHandle<()> = tokio::spawn(async move {
let mut bar = PhasedProgressBar::new(&format!("Adding {} to registry...", path.display()));
loop {
let snap = progress.snapshot();
bar.update(&snap);
if snap.overall.is_complete() {
break;
}
progress.changed().await
}
})
.into();
let progress_task =
progress.progress_bar_task(&format!("Adding {} to registry...", path.display()));
sign_phase.start();
let blake3 = file.blake3_mmap().await?;
@@ -252,7 +239,7 @@ pub async fn cli_add_asset(
.await?;
index_phase.complete();
progress_handle.complete();
progress.complete();
progress_task.await.with_kind(ErrorKind::Unknown)?;

View File

@@ -3,7 +3,7 @@ use std::panic::UnwindSafe;
use std::path::{Path, PathBuf};
use clap::Parser;
use helpers::{AtomicFile, NonDetachingJoinHandle};
use helpers::AtomicFile;
use imbl_value::{json, InternedString};
use itertools::Itertools;
use rpc_toolkit::{from_fn_async, Context, HandlerArgs, HandlerExt, ParentHandler};
@@ -12,7 +12,7 @@ use ts_rs::TS;
use crate::context::CliContext;
use crate::prelude::*;
use crate::progress::{FullProgressTracker, PhasedProgressBar};
use crate::progress::FullProgressTracker;
use crate::registry::asset::RegistryAsset;
use crate::registry::context::RegistryContext;
use crate::registry::os::index::OsVersionInfo;
@@ -135,29 +135,17 @@ async fn cli_get_os_asset(
.await
.with_kind(ErrorKind::Filesystem)?;
let mut progress = FullProgressTracker::new();
let progress_handle = progress.handle();
let progress = FullProgressTracker::new();
let mut download_phase =
progress_handle.add_phase(InternedString::intern("Downloading File"), Some(100));
progress.add_phase(InternedString::intern("Downloading File"), Some(100));
download_phase.set_total(res.commitment.size);
let reverify_phase = if reverify {
Some(progress_handle.add_phase(InternedString::intern("Reverifying File"), Some(10)))
Some(progress.add_phase(InternedString::intern("Reverifying File"), Some(10)))
} else {
None
};
let progress_task: NonDetachingJoinHandle<()> = tokio::spawn(async move {
let mut bar = PhasedProgressBar::new("Downloading...");
loop {
let snap = progress.snapshot();
bar.update(&snap);
if snap.overall.is_complete() {
break;
}
progress.changed().await
}
})
.into();
let progress_task = progress.progress_bar_task("Downloading...");
download_phase.start();
let mut download_writer = download_phase.writer(&mut *file);
@@ -177,7 +165,7 @@ async fn cli_get_os_asset(
reverify_phase.complete();
}
progress_handle.complete();
progress.complete();
progress_task.await.with_kind(ErrorKind::Unknown)?;
}

View File

@@ -3,7 +3,6 @@ use std::panic::UnwindSafe;
use std::path::PathBuf;
use clap::Parser;
use helpers::NonDetachingJoinHandle;
use imbl_value::InternedString;
use itertools::Itertools;
use rpc_toolkit::{from_fn_async, Context, HandlerArgs, HandlerExt, ParentHandler};
@@ -12,7 +11,7 @@ use ts_rs::TS;
use crate::context::CliContext;
use crate::prelude::*;
use crate::progress::{FullProgressTracker, PhasedProgressBar};
use crate::progress::FullProgressTracker;
use crate::registry::asset::RegistryAsset;
use crate::registry::context::RegistryContext;
use crate::registry::os::index::OsVersionInfo;
@@ -169,27 +168,15 @@ pub async fn cli_sign_asset(
let file = MultiCursorFile::from(tokio::fs::File::open(&path).await?);
let mut progress = FullProgressTracker::new();
let progress_handle = progress.handle();
let mut sign_phase =
progress_handle.add_phase(InternedString::intern("Signing File"), Some(10));
let mut index_phase = progress_handle.add_phase(
let progress = FullProgressTracker::new();
let mut sign_phase = progress.add_phase(InternedString::intern("Signing File"), Some(10));
let mut index_phase = progress.add_phase(
InternedString::intern("Adding Signature to Registry Index"),
Some(1),
);
let progress_task: NonDetachingJoinHandle<()> = tokio::spawn(async move {
let mut bar = PhasedProgressBar::new(&format!("Adding {} to registry...", path.display()));
loop {
let snap = progress.snapshot();
bar.update(&snap);
if snap.overall.is_complete() {
break;
}
progress.changed().await
}
})
.into();
let progress_task =
progress.progress_bar_task(&format!("Adding {} to registry...", path.display()));
sign_phase.start();
let blake3 = file.blake3_mmap().await?;
@@ -220,7 +207,7 @@ pub async fn cli_sign_asset(
.await?;
index_phase.complete();
progress_handle.complete();
progress.complete();
progress_task.await.with_kind(ErrorKind::Unknown)?;

View File

@@ -2,7 +2,6 @@ use std::path::PathBuf;
use std::sync::Arc;
use clap::Parser;
use helpers::NonDetachingJoinHandle;
use imbl_value::InternedString;
use itertools::Itertools;
use rpc_toolkit::HandlerArgs;
@@ -12,7 +11,7 @@ use url::Url;
use crate::context::CliContext;
use crate::prelude::*;
use crate::progress::{FullProgressTracker, PhasedProgressBar};
use crate::progress::FullProgressTracker;
use crate::registry::context::RegistryContext;
use crate::registry::package::index::PackageVersionInfo;
use crate::registry::signer::commitment::merkle_archive::MerkleArchiveCommitment;
@@ -53,7 +52,6 @@ pub async fn add_package(
let s9pk = S9pk::deserialize(
&Arc::new(HttpSource::new(ctx.client.clone(), url.clone()).await?),
Some(&commitment),
false,
)
.await?;
@@ -109,30 +107,18 @@ pub async fn cli_add_package(
..
}: HandlerArgs<CliContext, CliAddPackageParams>,
) -> Result<(), Error> {
let s9pk = S9pk::open(&file, None, false).await?;
let s9pk = S9pk::open(&file, None).await?;
let mut progress = FullProgressTracker::new();
let progress_handle = progress.handle();
let mut sign_phase = progress_handle.add_phase(InternedString::intern("Signing File"), Some(1));
let mut verify_phase =
progress_handle.add_phase(InternedString::intern("Verifying URL"), Some(100));
let mut index_phase = progress_handle.add_phase(
let progress = FullProgressTracker::new();
let mut sign_phase = progress.add_phase(InternedString::intern("Signing File"), Some(1));
let mut verify_phase = progress.add_phase(InternedString::intern("Verifying URL"), Some(100));
let mut index_phase = progress.add_phase(
InternedString::intern("Adding File to Registry Index"),
Some(1),
);
let progress_task: NonDetachingJoinHandle<()> = tokio::spawn(async move {
let mut bar = PhasedProgressBar::new(&format!("Adding {} to registry...", file.display()));
loop {
let snap = progress.snapshot();
bar.update(&snap);
if snap.overall.is_complete() {
break;
}
progress.changed().await
}
})
.into();
let progress_task =
progress.progress_bar_task(&format!("Adding {} to registry...", file.display()));
sign_phase.start();
let commitment = s9pk.as_archive().commitment().await?;
@@ -143,7 +129,6 @@ pub async fn cli_add_package(
let mut src = S9pk::deserialize(
&Arc::new(HttpSource::new(ctx.client.clone(), url.clone()).await?),
Some(&commitment),
false,
)
.await?;
src.serialize(&mut TrackingIO::new(0, tokio::io::sink()), true)
@@ -162,7 +147,7 @@ pub async fn cli_add_package(
.await?;
index_phase.complete();
progress_handle.complete();
progress.complete();
progress_task.await.with_kind(ErrorKind::Unknown)?;

View File

@@ -1,5 +1,5 @@
use std::time::{SystemTime, UNIX_EPOCH};
use std::collections::BTreeMap;
use std::time::{SystemTime, UNIX_EPOCH};
use axum::body::Body;
use axum::extract::Request;

View File

@@ -1,5 +1,8 @@
use std::collections::BTreeMap;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Mutex as SyncMutex;
use std::task::{Context, Poll};
use std::time::Duration;
use axum::extract::ws::WebSocket;
@@ -7,9 +10,10 @@ use axum::extract::Request;
use axum::response::Response;
use clap::builder::ValueParserFactory;
use futures::future::BoxFuture;
use futures::{Future, FutureExt};
use helpers::TimedResource;
use imbl_value::InternedString;
use tokio::sync::Mutex;
use tokio::sync::{broadcast, Mutex as AsyncMutex};
use ts_rs::TS;
#[allow(unused_imports)]
@@ -39,6 +43,11 @@ impl Guid {
Some(Guid(InternedString::intern(r)))
}
}
impl Default for Guid {
fn default() -> Self {
Self::new()
}
}
impl AsRef<str> for Guid {
fn as_ref(&self) -> &str {
self.0.as_ref()
@@ -68,21 +77,103 @@ impl std::fmt::Display for Guid {
}
}
pub type RestHandler =
Box<dyn FnOnce(Request) -> BoxFuture<'static, Result<Response, crate::Error>> + Send>;
pub struct RestFuture {
kill: Option<broadcast::Receiver<()>>,
fut: BoxFuture<'static, Result<Response, Error>>,
}
impl Future for RestFuture {
type Output = Result<Response, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.kill.as_ref().map_or(false, |k| !k.is_empty()) {
Poll::Ready(Err(Error::new(
eyre!("session killed"),
ErrorKind::Authorization,
)))
} else {
self.fut.poll_unpin(cx)
}
}
}
pub type RestHandler = Box<dyn FnOnce(Request) -> RestFuture + Send>;
pub type WebSocketHandler = Box<dyn FnOnce(WebSocket) -> BoxFuture<'static, ()> + Send>;
/// A websocket continuation future that can be aborted when its session is
/// killed.
///
/// Mirrors [`RestFuture`]: a buffered message on the optional `kill` channel
/// resolves the future immediately (websockets carry no error value, so the
/// output is simply `()`), otherwise the wrapped handler future is polled.
pub struct WebSocketFuture {
    kill: Option<broadcast::Receiver<()>>,
    fut: BoxFuture<'static, ()>,
}
impl Future for WebSocketFuture {
    type Output = ();
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // a pending message on the kill channel means the session was revoked
        let killed = matches!(&self.kill, Some(k) if !k.is_empty());
        if killed {
            return Poll::Ready(());
        }
        self.fut.poll_unpin(cx)
    }
}
pub type WebSocketHandler = Box<dyn FnOnce(WebSocket) -> WebSocketFuture + Send>;
/// A deferred HTTP continuation registered under a [`Guid`].
///
/// Each variant wraps its `FnOnce` handler in a [`TimedResource`], so a
/// continuation is consumed at most once and expires after the timeout passed
/// to its constructor.
pub enum RpcContinuation {
    // plain request/response continuation
    Rest(TimedResource<RestHandler>),
    // websocket upgrade continuation
    WebSocket(TimedResource<WebSocketHandler>),
}
impl RpcContinuation {
pub fn rest(handler: RestHandler, timeout: Duration) -> Self {
RpcContinuation::Rest(TimedResource::new(handler, timeout))
pub fn rest<F, Fut>(handler: F, timeout: Duration) -> Self
where
F: FnOnce(Request) -> Fut + Send + 'static,
Fut: Future<Output = Result<Response, Error>> + Send + 'static,
{
RpcContinuation::Rest(TimedResource::new(
Box::new(|req| RestFuture {
kill: None,
fut: handler(req).boxed(),
}),
timeout,
))
}
pub fn ws(handler: WebSocketHandler, timeout: Duration) -> Self {
RpcContinuation::WebSocket(TimedResource::new(handler, timeout))
pub fn ws<F, Fut>(handler: F, timeout: Duration) -> Self
where
F: FnOnce(WebSocket) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
RpcContinuation::WebSocket(TimedResource::new(
Box::new(|ws| WebSocketFuture {
kill: None,
fut: handler(ws).boxed(),
}),
timeout,
))
}
pub fn rest_authed<Ctx, T, F, Fut>(ctx: Ctx, session: T, handler: F, timeout: Duration) -> Self
where
Ctx: AsRef<OpenAuthedContinuations<T>>,
T: Eq + Ord,
F: FnOnce(Request) -> Fut + Send + 'static,
Fut: Future<Output = Result<Response, Error>> + Send + 'static,
{
let kill = Some(ctx.as_ref().subscribe_to_kill(session));
RpcContinuation::Rest(TimedResource::new(
Box::new(|req| RestFuture {
kill,
fut: handler(req).boxed(),
}),
timeout,
))
}
pub fn ws_authed<Ctx, T, F, Fut>(ctx: Ctx, session: T, handler: F, timeout: Duration) -> Self
where
Ctx: AsRef<OpenAuthedContinuations<T>>,
T: Eq + Ord,
F: FnOnce(WebSocket) -> Fut + Send + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let kill = Some(ctx.as_ref().subscribe_to_kill(session));
RpcContinuation::WebSocket(TimedResource::new(
Box::new(|ws| WebSocketFuture {
kill,
fut: handler(ws).boxed(),
}),
timeout,
))
}
pub fn is_timed_out(&self) -> bool {
match self {
@@ -92,10 +183,10 @@ impl RpcContinuation {
}
}
pub struct RpcContinuations(Mutex<BTreeMap<Guid, RpcContinuation>>);
pub struct RpcContinuations(AsyncMutex<BTreeMap<Guid, RpcContinuation>>);
impl RpcContinuations {
pub fn new() -> Self {
RpcContinuations(Mutex::new(BTreeMap::new()))
RpcContinuations(AsyncMutex::new(BTreeMap::new()))
}
#[instrument(skip_all)]
@@ -141,3 +232,28 @@ impl RpcContinuations {
x.get().await
}
}
/// Registry of kill-switches for continuations opened by authenticated
/// sessions, keyed by session identifier.
///
/// Each session maps to a `broadcast::Sender`; futures subscribe to it and
/// treat a buffered message as "this session was revoked".
pub struct OpenAuthedContinuations<Key: Eq + Ord>(SyncMutex<BTreeMap<Key, broadcast::Sender<()>>>);
impl<T> OpenAuthedContinuations<T>
where
    T: Eq + Ord,
{
    /// Creates an empty registry.
    pub fn new() -> Self {
        OpenAuthedContinuations(SyncMutex::new(BTreeMap::new()))
    }
    /// Signals every continuation subscribed under `session` and forgets the
    /// session's channel.
    pub fn kill(&self, session: &T) {
        let sender = self.0.lock().unwrap().remove(session);
        if let Some(channel) = sender {
            // errors just mean there are no live receivers — nothing to kill
            channel.send(()).ok();
        }
    }
    /// Returns a receiver that is signalled when `session` is killed,
    /// creating the session's channel on first use.
    fn subscribe_to_kill(&self, session: T) -> broadcast::Receiver<()> {
        self.0
            .lock()
            .unwrap()
            .entry(session)
            .or_insert_with(|| broadcast::channel(1).0)
            .subscribe()
    }
}

View File

@@ -211,7 +211,10 @@ impl<S: FileSource + Clone> DirectoryContents<S> {
if !filter(path) {
if v.hash.is_none() {
return Err(Error::new(
eyre!("cannot filter out unhashed file, run `update_hashes` first"),
eyre!(
"cannot filter out unhashed file {}, run `update_hashes` first",
path.display()
),
ErrorKind::InvalidRequest,
));
}

View File

@@ -0,0 +1,103 @@
use std::ffi::OsStr;
use std::path::Path;
use crate::prelude::*;
use crate::s9pk::merkle_archive::directory_contents::DirectoryContents;
use crate::s9pk::merkle_archive::source::FileSource;
use crate::s9pk::merkle_archive::Entry;
/// Tracks the files expected to be present in an s9pk archive.
///
/// Holds a read-only view of the archive's directory tree (`dir`) and
/// accumulates the set of paths that passed validation (`keep`), which can
/// later be converted into a [`Filter`] via `into_filter`.
pub struct Expected<'a, T> {
    keep: DirectoryContents<()>,
    dir: &'a DirectoryContents<T>,
}
impl<'a, T> Expected<'a, T> {
    /// Creates a tracker over `dir` with an initially empty keep-set.
    pub fn new(dir: &'a DirectoryContents<T>) -> Self {
        Expected {
            dir,
            keep: DirectoryContents::new(),
        }
    }
}
impl<'a, T: Clone> Expected<'a, T> {
    /// Requires the regular file at `path` to exist in the archive and records
    /// it in the keep-set.
    ///
    /// Errors with `ErrorKind::ParseS9pk` if `path` is absent or is not a file.
    pub fn check_file(&mut self, path: impl AsRef<Path>) -> Result<(), Error> {
        if self
            .dir
            .get_path(path.as_ref())
            .and_then(|e| e.as_file())
            .is_some()
        {
            self.keep.insert_path(path, Entry::file(()))?;
            Ok(())
        } else {
            Err(Error::new(
                eyre!("file {} missing from archive", path.as_ref().display()),
                ErrorKind::ParseS9pk,
            ))
        }
    }
    /// Requires exactly one file in the archive whose stem matches `path` and
    /// whose extension satisfies `valid_extension`, and records it in the
    /// keep-set.
    ///
    /// Errors with `ErrorKind::ParseS9pk` if the parent directory is missing,
    /// if no matching file exists, or if more than one matches.
    pub fn check_stem(
        &mut self,
        path: impl AsRef<Path>,
        mut valid_extension: impl FnMut(Option<&OsStr>) -> bool,
    ) -> Result<(), Error> {
        // Split `path` into the archive directory to search and the bare stem.
        let (dir, stem) = if let Some(parent) = path.as_ref().parent().filter(|p| *p != Path::new("")) {
            (
                self.dir
                    .get_path(parent)
                    .and_then(|e| e.as_directory())
                    .ok_or_else(|| {
                        Error::new(
                            eyre!("directory {} missing from archive", parent.display()),
                            ErrorKind::ParseS9pk,
                        )
                    })?,
                path.as_ref().strip_prefix(parent).unwrap(),
            )
        } else {
            (self.dir, path.as_ref())
        };
        // Fold over all files sharing the stem. The *inner* Result is the
        // accumulator: it stays Err("missing") until one candidate with a
        // valid extension is seen (Ok(name)); a second match short-circuits
        // the try_fold with a "more than one file" error. The trailing `??`
        // unwraps the fold result and then the accumulator.
        let name = dir
            .with_stem(&stem.as_os_str().to_string_lossy())
            .filter(|(_, e)| e.as_file().is_some())
            .try_fold(
                Err(Error::new(
                    eyre!(
                        "file {} with valid extension missing from archive",
                        path.as_ref().display()
                    ),
                    ErrorKind::ParseS9pk,
                )),
                |acc, (name, _)|
                    if valid_extension(Path::new(&*name).extension()) {
                        match acc {
                            Ok(_) => Err(Error::new(
                                eyre!(
                                    "more than one file matching {} with valid extension in archive",
                                    path.as_ref().display()
                                ),
                                ErrorKind::ParseS9pk,
                            )),
                            Err(_) => Ok(Ok(name))
                        }
                    } else {
                        Ok(acc)
                    }
            )??;
        // keep the concrete file name (stem + winning extension)
        self.keep
            .insert_path(path.as_ref().with_file_name(name), Entry::file(()))?;
        Ok(())
    }
    /// Consumes the tracker, yielding a [`Filter`] over all checked paths.
    pub fn into_filter(self) -> Filter {
        Filter(self.keep)
    }
}
pub struct Filter(DirectoryContents<()>);
impl Filter {
pub fn keep_checked<T: FileSource + Clone>(&self, dir: &mut DirectoryContents<T>) -> Result<(), Error> {
dir.filter(|path| self.0.get_path(path).is_some())
}
}

View File

@@ -19,6 +19,7 @@ use crate::util::serde::Base64;
use crate::CAP_1_MiB;
pub mod directory_contents;
pub mod expected;
pub mod file_contents;
pub mod hash;
pub mod sink;
@@ -217,6 +218,9 @@ impl<S> Entry<S> {
pub fn file(source: S) -> Self {
Self::new(EntryContents::File(FileContents::new(source)))
}
pub fn directory(directory: DirectoryContents<S>) -> Self {
Self::new(EntryContents::Directory(directory))
}
pub fn hash(&self) -> Option<(Hash, u64)> {
self.hash
}

View File

@@ -280,3 +280,8 @@ impl<S: ArchiveSource> FileSource for Section<S> {
self.source.copy_to(self.position, self.size, w).await
}
}
/// A type-erased async reader, for call sites where the concrete reader type
/// varies at runtime.
pub type DynRead = Box<dyn AsyncRead + Unpin + Send + Sync + 'static>;
/// Boxes `r` into a [`DynRead`].
pub fn into_dyn_read<R: AsyncRead + Unpin + Send + Sync + 'static>(r: R) -> DynRead {
    Box::new(r)
}

View File

@@ -97,7 +97,8 @@ impl ArchiveSource for MultiCursorFile {
.ok()
.map(|m| m.len())
}
async fn fetch_all(&self) -> Result<impl AsyncRead + Unpin + Send, Error> {
#[allow(refining_impl_trait)]
async fn fetch_all(&self) -> Result<impl AsyncRead + Unpin + Send + 'static, Error> {
use tokio::io::AsyncSeekExt;
let mut file = self.cursor().await?;

View File

@@ -1,32 +1,26 @@
use std::collections::BTreeSet;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::path::PathBuf;
use clap::Parser;
use itertools::Itertools;
use models::ImageId;
use rpc_toolkit::{from_fn_async, Empty, HandlerExt, ParentHandler};
use serde::{Deserialize, Serialize};
use tokio::fs::File;
use tokio::process::Command;
use ts_rs::TS;
use crate::context::CliContext;
use crate::prelude::*;
use crate::s9pk::manifest::Manifest;
use crate::s9pk::merkle_archive::source::DynFileSource;
use crate::s9pk::merkle_archive::Entry;
use crate::s9pk::v2::compat::CONTAINER_TOOL;
use crate::s9pk::v2::pack::ImageConfig;
use crate::s9pk::v2::SIG_CONTEXT;
use crate::s9pk::S9pk;
use crate::util::io::TmpDir;
use crate::util::serde::{apply_expr, HandlerExtSerde};
use crate::util::Invoke;
pub const SKIP_ENV: &[&str] = &["TERM", "container", "HOME", "HOSTNAME"];
pub fn s9pk() -> ParentHandler<CliContext> {
ParentHandler::new()
.subcommand("pack", from_fn_async(super::v2::pack::pack).no_display())
.subcommand("edit", edit())
.subcommand("inspect", inspect())
}
@@ -77,117 +71,21 @@ fn inspect() -> ParentHandler<CliContext, S9pkPath> {
#[derive(Deserialize, Serialize, Parser, TS)]
struct AddImageParams {
id: ImageId,
image: String,
arches: Option<Vec<String>>,
#[command(flatten)]
config: ImageConfig,
}
async fn add_image(
ctx: CliContext,
AddImageParams { id, image, arches }: AddImageParams,
AddImageParams { id, config }: AddImageParams,
S9pkPath { s9pk: s9pk_path }: S9pkPath,
) -> Result<(), Error> {
let mut s9pk = S9pk::from_file(super::load(&ctx, &s9pk_path).await?, false)
let mut s9pk = S9pk::from_file(super::load(&ctx, &s9pk_path).await?)
.await?
.into_dyn();
let arches: BTreeSet<_> = arches
.unwrap_or_else(|| vec!["x86_64".to_owned(), "aarch64".to_owned()])
.into_iter()
.collect();
s9pk.as_manifest_mut().images.insert(id, config);
let tmpdir = TmpDir::new().await?;
for arch in arches {
let sqfs_path = tmpdir.join(format!("image.{arch}.squashfs"));
let docker_platform = if arch == "x86_64" {
"--platform=linux/amd64".to_owned()
} else if arch == "aarch64" {
"--platform=linux/arm64".to_owned()
} else {
format!("--platform=linux/{arch}")
};
let env = String::from_utf8(
Command::new(CONTAINER_TOOL)
.arg("run")
.arg("--rm")
.arg(&docker_platform)
.arg("--entrypoint")
.arg("env")
.arg(&image)
.invoke(ErrorKind::Docker)
.await?,
)?
.lines()
.filter(|l| {
l.trim()
.split_once("=")
.map_or(false, |(v, _)| !SKIP_ENV.contains(&v))
})
.join("\n")
+ "\n";
let workdir = Path::new(
String::from_utf8(
Command::new(CONTAINER_TOOL)
.arg("run")
.arg(&docker_platform)
.arg("--rm")
.arg("--entrypoint")
.arg("pwd")
.arg(&image)
.invoke(ErrorKind::Docker)
.await?,
)?
.trim(),
)
.to_owned();
let container_id = String::from_utf8(
Command::new(CONTAINER_TOOL)
.arg("create")
.arg(&docker_platform)
.arg(&image)
.invoke(ErrorKind::Docker)
.await?,
)?;
Command::new("bash")
.arg("-c")
.arg(format!(
"{CONTAINER_TOOL} export {container_id} | mksquashfs - {sqfs} -tar",
container_id = container_id.trim(),
sqfs = sqfs_path.display()
))
.invoke(ErrorKind::Docker)
.await?;
Command::new(CONTAINER_TOOL)
.arg("rm")
.arg(container_id.trim())
.invoke(ErrorKind::Docker)
.await?;
let archive = s9pk.as_archive_mut();
archive.set_signer(ctx.developer_key()?.clone(), SIG_CONTEXT);
archive.contents_mut().insert_path(
Path::new("images")
.join(&arch)
.join(&id)
.with_extension("squashfs"),
Entry::file(DynFileSource::new(sqfs_path)),
)?;
archive.contents_mut().insert_path(
Path::new("images")
.join(&arch)
.join(&id)
.with_extension("env"),
Entry::file(DynFileSource::new(Arc::<[u8]>::from(Vec::from(env)))),
)?;
archive.contents_mut().insert_path(
Path::new("images")
.join(&arch)
.join(&id)
.with_extension("json"),
Entry::file(DynFileSource::new(Arc::<[u8]>::from(
serde_json::to_vec(&serde_json::json!({
"workdir": workdir
}))
.with_kind(ErrorKind::Serialization)?,
))),
)?;
}
s9pk.as_manifest_mut().images.insert(id);
s9pk.load_images(&tmpdir).await?;
s9pk.validate_and_filter(None)?;
let tmp_path = s9pk_path.with_extension("s9pk.tmp");
let mut tmp_file = File::create(&tmp_path).await?;
s9pk.serialize(&mut tmp_file, true).await?;
@@ -206,7 +104,7 @@ async fn edit_manifest(
EditManifestParams { expression }: EditManifestParams,
S9pkPath { s9pk: s9pk_path }: S9pkPath,
) -> Result<Manifest, Error> {
let mut s9pk = S9pk::from_file(super::load(&ctx, &s9pk_path).await?, false).await?;
let mut s9pk = S9pk::from_file(super::load(&ctx, &s9pk_path).await?).await?;
let old = serde_json::to_value(s9pk.as_manifest()).with_kind(ErrorKind::Serialization)?;
*s9pk.as_manifest_mut() = serde_json::from_value(apply_expr(old.into(), &expression)?.into())
.with_kind(ErrorKind::Serialization)?;
@@ -227,7 +125,7 @@ async fn file_tree(
_: Empty,
S9pkPath { s9pk }: S9pkPath,
) -> Result<Vec<PathBuf>, Error> {
let s9pk = S9pk::from_file(super::load(&ctx, &s9pk).await?, false).await?;
let s9pk = S9pk::from_file(super::load(&ctx, &s9pk).await?).await?;
Ok(s9pk.as_archive().contents().file_paths(""))
}
@@ -244,7 +142,7 @@ async fn cat(
) -> Result<(), Error> {
use crate::s9pk::merkle_archive::source::FileSource;
let s9pk = S9pk::from_file(super::load(&ctx, &s9pk).await?, false).await?;
let s9pk = S9pk::from_file(super::load(&ctx, &s9pk).await?).await?;
tokio::io::copy(
&mut s9pk
.as_archive()
@@ -266,6 +164,6 @@ async fn inspect_manifest(
_: Empty,
S9pkPath { s9pk }: S9pkPath,
) -> Result<Manifest, Error> {
let s9pk = S9pk::from_file(super::load(&ctx, &s9pk).await?, false).await?;
let s9pk = S9pk::from_file(super::load(&ctx, &s9pk).await?).await?;
Ok(s9pk.as_manifest().clone())
}

View File

@@ -1,6 +1,5 @@
use std::collections::{BTreeMap, BTreeSet};
use std::io::Cursor;
use std::path::{Path, PathBuf};
use std::collections::BTreeMap;
use std::path::Path;
use std::sync::Arc;
use itertools::Itertools;
@@ -14,49 +13,18 @@ use crate::prelude::*;
use crate::s9pk::manifest::Manifest;
use crate::s9pk::merkle_archive::directory_contents::DirectoryContents;
use crate::s9pk::merkle_archive::source::multi_cursor_file::MultiCursorFile;
use crate::s9pk::merkle_archive::source::{FileSource, Section};
use crate::s9pk::merkle_archive::source::Section;
use crate::s9pk::merkle_archive::{Entry, MerkleArchive};
use crate::s9pk::rpc::SKIP_ENV;
use crate::s9pk::v1::manifest::{Manifest as ManifestV1, PackageProcedure};
use crate::s9pk::v1::reader::S9pkReader;
use crate::s9pk::v2::pack::{PackSource, CONTAINER_TOOL};
use crate::s9pk::v2::{S9pk, SIG_CONTEXT};
use crate::util::io::TmpDir;
use crate::util::Invoke;
pub const MAGIC_AND_VERSION: &[u8] = &[0x3b, 0x3b, 0x01];
#[cfg(not(feature = "docker"))]
pub const CONTAINER_TOOL: &str = "podman";
#[cfg(feature = "docker")]
pub const CONTAINER_TOOL: &str = "docker";
type DynRead = Box<dyn AsyncRead + Unpin + Send + Sync + 'static>;
fn into_dyn_read<R: AsyncRead + Unpin + Send + Sync + 'static>(r: R) -> DynRead {
Box::new(r)
}
#[derive(Clone)]
enum CompatSource {
Buffered(Arc<[u8]>),
File(PathBuf),
}
impl FileSource for CompatSource {
type Reader = Box<dyn AsyncRead + Unpin + Send + Sync + 'static>;
async fn size(&self) -> Result<u64, Error> {
match self {
Self::Buffered(a) => Ok(a.len() as u64),
Self::File(f) => Ok(tokio::fs::metadata(f).await?.len()),
}
}
async fn reader(&self) -> Result<Self::Reader, Error> {
match self {
Self::Buffered(a) => Ok(into_dyn_read(Cursor::new(a.clone()))),
Self::File(f) => Ok(into_dyn_read(File::open(f).await?)),
}
}
}
impl S9pk<Section<MultiCursorFile>> {
#[instrument(skip_all)]
pub async fn from_v1<R: AsyncRead + AsyncSeek + Unpin + Send + Sync>(
@@ -66,7 +34,7 @@ impl S9pk<Section<MultiCursorFile>> {
) -> Result<Self, Error> {
let scratch_dir = TmpDir::new().await?;
let mut archive = DirectoryContents::<CompatSource>::new();
let mut archive = DirectoryContents::<PackSource>::new();
// manifest.json
let manifest_raw = reader.manifest().await?;
@@ -88,21 +56,21 @@ impl S9pk<Section<MultiCursorFile>> {
let license: Arc<[u8]> = reader.license().await?.to_vec().await?.into();
archive.insert_path(
"LICENSE.md",
Entry::file(CompatSource::Buffered(license.into())),
Entry::file(PackSource::Buffered(license.into())),
)?;
// instructions.md
let instructions: Arc<[u8]> = reader.instructions().await?.to_vec().await?.into();
archive.insert_path(
"instructions.md",
Entry::file(CompatSource::Buffered(instructions.into())),
Entry::file(PackSource::Buffered(instructions.into())),
)?;
// icon.md
let icon: Arc<[u8]> = reader.icon().await?.to_vec().await?.into();
archive.insert_path(
format!("icon.{}", manifest.assets.icon_type()),
Entry::file(CompatSource::Buffered(icon.into())),
Entry::file(PackSource::Buffered(icon.into())),
)?;
// images
@@ -122,7 +90,9 @@ impl S9pk<Section<MultiCursorFile>> {
.invoke(ErrorKind::Docker)
.await?;
for (image, system) in &images {
new_manifest.images.insert(image.clone());
let mut image_config = new_manifest.images.remove(image).unwrap_or_default();
image_config.arch.insert(arch.as_str().into());
new_manifest.images.insert(image.clone(), image_config);
let sqfs_path = images_dir.join(image).with_extension("squashfs");
let image_name = if *system {
format!("start9/{}:latest", image)
@@ -190,21 +160,21 @@ impl S9pk<Section<MultiCursorFile>> {
.join(&arch)
.join(&image)
.with_extension("squashfs"),
Entry::file(CompatSource::File(sqfs_path)),
Entry::file(PackSource::File(sqfs_path)),
)?;
archive.insert_path(
Path::new("images")
.join(&arch)
.join(&image)
.with_extension("env"),
Entry::file(CompatSource::Buffered(Vec::from(env).into())),
Entry::file(PackSource::Buffered(Vec::from(env).into())),
)?;
archive.insert_path(
Path::new("images")
.join(&arch)
.join(&image)
.with_extension("json"),
Entry::file(CompatSource::Buffered(
Entry::file(PackSource::Buffered(
serde_json::to_vec(&serde_json::json!({
"workdir": workdir
}))
@@ -239,8 +209,10 @@ impl S9pk<Section<MultiCursorFile>> {
.invoke(ErrorKind::Filesystem)
.await?;
archive.insert_path(
Path::new("assets").join(&asset_id),
Entry::file(CompatSource::File(sqfs_path)),
Path::new("assets")
.join(&asset_id)
.with_extension("squashfs"),
Entry::file(PackSource::File(sqfs_path)),
)?;
}
@@ -267,12 +239,12 @@ impl S9pk<Section<MultiCursorFile>> {
.await?;
archive.insert_path(
Path::new("javascript.squashfs"),
Entry::file(CompatSource::File(sqfs_path)),
Entry::file(PackSource::File(sqfs_path)),
)?;
archive.insert_path(
"manifest.json",
Entry::file(CompatSource::Buffered(
Entry::file(PackSource::Buffered(
serde_json::to_vec::<Manifest>(&new_manifest)
.with_kind(ErrorKind::Serialization)?
.into(),
@@ -289,7 +261,6 @@ impl S9pk<Section<MultiCursorFile>> {
Ok(S9pk::deserialize(
&MultiCursorFile::from(File::open(destination.as_ref()).await?),
None,
false,
)
.await?)
}
@@ -310,7 +281,7 @@ impl From<ManifestV1> for Manifest {
marketing_site: value.marketing_site.unwrap_or_else(|| default_url.clone()),
donation_url: value.donation_url,
description: value.description,
images: BTreeSet::new(),
images: BTreeMap::new(),
assets: value
.volumes
.iter()

View File

@@ -1,10 +1,11 @@
use std::collections::{BTreeMap, BTreeSet};
use std::path::Path;
use color_eyre::eyre::eyre;
use helpers::const_true;
use imbl_value::InternedString;
pub use models::PackageId;
use models::{ImageId, VolumeId};
use models::{mime, ImageId, VolumeId};
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use url::Url;
@@ -12,6 +13,9 @@ use url::Url;
use crate::dependencies::Dependencies;
use crate::prelude::*;
use crate::s9pk::git_hash::GitHash;
use crate::s9pk::merkle_archive::directory_contents::DirectoryContents;
use crate::s9pk::merkle_archive::expected::{Expected, Filter};
use crate::s9pk::v2::pack::ImageConfig;
use crate::util::serde::Regex;
use crate::util::VersionString;
use crate::version::{Current, VersionT};
@@ -42,7 +46,7 @@ pub struct Manifest {
#[ts(type = "string | null")]
pub donation_url: Option<Url>,
pub description: Description,
pub images: BTreeSet<ImageId>,
pub images: BTreeMap<ImageId, ImageConfig>,
pub assets: BTreeSet<VolumeId>, // TODO: AssetsId
pub volumes: BTreeSet<VolumeId>,
#[serde(default)]
@@ -59,13 +63,90 @@ pub struct Manifest {
#[serde(default = "const_true")]
pub has_config: bool,
}
impl Manifest {
pub fn validate_for<'a, T: Clone>(
&self,
arch: Option<&str>,
archive: &'a DirectoryContents<T>,
) -> Result<Filter, Error> {
let mut expected = Expected::new(archive);
expected.check_file("manifest.json")?;
expected.check_stem("icon", |ext| {
ext.and_then(|e| e.to_str())
.and_then(mime)
.map_or(false, |mime| mime.starts_with("image/"))
})?;
expected.check_file("LICENSE.md")?;
expected.check_file("instructions.md")?;
expected.check_file("javascript.squashfs")?;
for assets in &self.assets {
expected.check_file(Path::new("assets").join(assets).with_extension("squashfs"))?;
}
for (image_id, config) in &self.images {
let mut check_arch = |arch: &str| {
let mut arch = arch;
if let Err(e) = expected.check_file(
Path::new("images")
.join(arch)
.join(image_id)
.with_extension("squashfs"),
) {
if let Some(emulate_as) = &config.emulate_missing_as {
expected.check_file(
Path::new("images")
.join(arch)
.join(image_id)
.with_extension("squashfs"),
)?;
arch = &**emulate_as;
} else {
return Err(e);
}
}
expected.check_file(
Path::new("images")
.join(arch)
.join(image_id)
.with_extension("json"),
)?;
expected.check_file(
Path::new("images")
.join(arch)
.join(image_id)
.with_extension("env"),
)?;
Ok(())
};
if let Some(arch) = arch {
check_arch(arch)?;
} else if let Some(arches) = &self.hardware_requirements.arch {
for arch in arches {
check_arch(arch)?;
}
} else if let Some(arch) = config.emulate_missing_as.as_deref() {
if !config.arch.contains(arch) {
return Err(Error::new(
eyre!("`emulateMissingAs` must match an included `arch`"),
ErrorKind::ParseS9pk,
));
}
for arch in &config.arch {
check_arch(&arch)?;
}
} else {
return Err(Error::new(eyre!("`emulateMissingAs` required for all images if no `arch` specified in `hardwareRequirements`"), ErrorKind::ParseS9pk));
}
}
Ok(expected.into_filter())
}
}
#[derive(Clone, Debug, Default, Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct HardwareRequirements {
#[serde(default)]
#[ts(type = "{ [key: string]: string }")]
#[ts(type = "{ [key: string]: string }")] // TODO more specific key
pub device: BTreeMap<String, Regex>,
#[ts(type = "number | null")]
pub ram: Option<u64>,

View File

@@ -14,7 +14,8 @@ use crate::s9pk::merkle_archive::sink::Sink;
use crate::s9pk::merkle_archive::source::multi_cursor_file::MultiCursorFile;
use crate::s9pk::merkle_archive::source::{ArchiveSource, DynFileSource, FileSource, Section};
use crate::s9pk::merkle_archive::{Entry, MerkleArchive};
use crate::ARCH;
use crate::s9pk::v2::pack::{ImageSource, PackSource};
use crate::util::io::TmpDir;
const MAGIC_AND_VERSION: &[u8] = &[0x3b, 0x3b, 0x02];
@@ -22,6 +23,7 @@ pub const SIG_CONTEXT: &str = "s9pk";
pub mod compat;
pub mod manifest;
pub mod pack;
/**
/
@@ -34,10 +36,14 @@ pub mod manifest;
│ └── <id>.squashfs (xN)
└── images
└── <arch>
├── <id>.json (xN)
├── <id>.env (xN)
└── <id>.squashfs (xN)
*/
// this sorts the s9pk to optimize such that the parts that are used first appear earlier in the s9pk
// this is useful for manipulating an s9pk while partially downloaded on a source that does not support
// random access
fn priority(s: &str) -> Option<usize> {
match s {
"manifest.json" => Some(0),
@@ -51,26 +57,6 @@ fn priority(s: &str) -> Option<usize> {
}
}
/// Decides whether an archive path should be retained when filtering an s9pk
/// for the current architecture (`ARCH`). Top-level metadata files are always
/// kept; under `assets/` only squashfs files; under `images/` only the entries
/// for the running arch.
fn filter(p: &Path) -> bool {
    let depth = p.iter().count();
    if depth == 1 {
        let name = p.file_name();
        name == Some(OsStr::new("manifest.json"))
            || p.file_stem() == Some(OsStr::new("icon"))
            || name == Some(OsStr::new("LICENSE.md"))
            || name == Some(OsStr::new("instructions.md"))
            || name == Some(OsStr::new("javascript.squashfs"))
            || name == Some(OsStr::new("assets"))
            || name == Some(OsStr::new("images"))
    } else if depth == 2 {
        if p.parent() == Some(Path::new("assets")) {
            p.extension().map_or(false, |ext| ext == "squashfs")
        } else if p.parent() == Some(Path::new("images")) {
            p.file_name() == Some(OsStr::new(&*ARCH))
        } else {
            false
        }
    } else if depth == 3 && p.parent() == Some(&*Path::new("images").join(&*ARCH)) {
        p.extension()
            .map_or(false, |ext| ext == "squashfs" || ext == "env")
    } else {
        false
    }
}
#[derive(Clone)]
pub struct S9pk<S = Section<MultiCursorFile>> {
pub manifest: Manifest,
@@ -108,6 +94,11 @@ impl<S: FileSource + Clone> S9pk<S> {
})
}
/// Validates the package contents for `arch` (or for all declared arches when
/// `None`) and drops any archive entries not required for it.
pub fn validate_and_filter(&mut self, arch: Option<&str>) -> Result<(), Error> {
    // `validate_for` both checks that required files are present and returns
    // the filter describing which entries to retain.
    let filter = self.manifest.validate_for(arch, self.archive.contents())?;
    filter.keep_checked(self.archive.contents_mut())
}
pub async fn icon(&self) -> Result<(InternedString, FileContents<S>), Error> {
let mut best_icon = None;
for (path, icon) in self
@@ -174,12 +165,37 @@ impl<S: FileSource + Clone> S9pk<S> {
}
}
impl<S: From<PackSource> + FileSource + Clone> S9pk<S> {
    /// Materializes every image declared in the manifest into the archive.
    ///
    /// For each image and each of its declared arches, the image source is
    /// loaded into the archive contents; afterwards the source is rewritten to
    /// `Packed`, since the artifacts now live inside the s9pk itself.
    pub async fn load_images(&mut self, tmpdir: &TmpDir) -> Result<(), Error> {
        let id = &self.manifest.id;
        let version = &self.manifest.version;
        for (image_id, image_config) in &mut self.manifest.images {
            // the manifest is mutated below (source -> Packed), so it must be
            // re-serialized when the s9pk is written out
            self.manifest_dirty = true;
            for arch in &image_config.arch {
                image_config
                    .source
                    .load(
                        tmpdir,
                        id,
                        version,
                        image_id,
                        arch,
                        self.archive.contents_mut(),
                    )
                    .await?;
            }
            image_config.source = ImageSource::Packed;
        }
        Ok(())
    }
}
impl<S: ArchiveSource + Clone> S9pk<Section<S>> {
#[instrument(skip_all)]
pub async fn deserialize(
source: &S,
commitment: Option<&MerkleArchiveCommitment>,
apply_filter: bool,
) -> Result<Self, Error> {
use tokio::io::AsyncReadExt;
@@ -201,10 +217,6 @@ impl<S: ArchiveSource + Clone> S9pk<Section<S>> {
let mut archive =
MerkleArchive::deserialize(source, SIG_CONTEXT, &mut header, commitment).await?;
if apply_filter {
archive.filter(filter)?;
}
archive.sort_by(|a, b| match (priority(a), priority(b)) {
(Some(a), Some(b)) => a.cmp(&b),
(Some(_), None) => std::cmp::Ordering::Less,
@@ -216,15 +228,11 @@ impl<S: ArchiveSource + Clone> S9pk<Section<S>> {
}
}
impl S9pk {
pub async fn from_file(file: File, apply_filter: bool) -> Result<Self, Error> {
Self::deserialize(&MultiCursorFile::from(file), None, apply_filter).await
pub async fn from_file(file: File) -> Result<Self, Error> {
Self::deserialize(&MultiCursorFile::from(file), None).await
}
pub async fn open(
path: impl AsRef<Path>,
id: Option<&PackageId>,
apply_filter: bool,
) -> Result<Self, Error> {
let res = Self::from_file(tokio::fs::File::open(path).await?, apply_filter).await?;
pub async fn open(path: impl AsRef<Path>, id: Option<&PackageId>) -> Result<Self, Error> {
let res = Self::from_file(tokio::fs::File::open(path).await?).await?;
if let Some(id) = id {
ensure_code!(
&res.as_manifest().id == id,

View File

@@ -0,0 +1,536 @@
use std::collections::BTreeSet;
use std::ffi::OsStr;
use std::io::Cursor;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use clap::Parser;
use futures::future::{ready, BoxFuture};
use futures::{FutureExt, TryStreamExt};
use imbl_value::InternedString;
use models::{ImageId, PackageId, VersionString};
use serde::{Deserialize, Serialize};
use tokio::fs::File;
use tokio::io::AsyncRead;
use tokio::process::Command;
use tokio::sync::OnceCell;
use tokio_stream::wrappers::ReadDirStream;
use ts_rs::TS;
use crate::context::CliContext;
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::s9pk::merkle_archive::directory_contents::DirectoryContents;
use crate::s9pk::merkle_archive::source::multi_cursor_file::MultiCursorFile;
use crate::s9pk::merkle_archive::source::{
into_dyn_read, ArchiveSource, DynFileSource, FileSource,
};
use crate::s9pk::merkle_archive::{Entry, MerkleArchive};
use crate::s9pk::v2::SIG_CONTEXT;
use crate::s9pk::S9pk;
use crate::util::io::TmpDir;
use crate::util::Invoke;
// Container runtime CLI used for building/pulling/exporting images while
// packing; podman by default, docker when the `docker` feature is enabled.
#[cfg(not(feature = "docker"))]
pub const CONTAINER_TOOL: &str = "podman";
#[cfg(feature = "docker")]
pub const CONTAINER_TOOL: &str = "docker";
/// A directory (or tarball) lazily converted to a squashfs image on first use.
pub struct SqfsDir {
    /// Source directory or `.tar` file to pack.
    path: PathBuf,
    /// Scratch space where the generated squashfs is staged.
    tmpdir: Arc<TmpDir>,
    /// Memoized handle to the generated squashfs file.
    sqfs: OnceCell<MultiCursorFile>,
}
impl SqfsDir {
pub fn new(path: PathBuf, tmpdir: Arc<TmpDir>) -> Self {
Self {
path,
tmpdir,
sqfs: OnceCell::new(),
}
}
async fn file(&self) -> Result<&MultiCursorFile, Error> {
self.sqfs
.get_or_try_init(|| async move {
let guid = Guid::new();
let path = self.tmpdir.join(guid.as_ref()).with_extension("squashfs");
let mut cmd = Command::new("mksquashfs");
if self.path.extension().and_then(|s| s.to_str()) == Some("tar") {
cmd.arg("-tar");
}
cmd.arg(&self.path)
.arg(&path)
.invoke(ErrorKind::Filesystem)
.await?;
Ok(MultiCursorFile::from(
File::open(&path)
.await
.with_ctx(|_| (ErrorKind::Filesystem, path.display()))?,
))
})
.await
}
}
/// Backing data for a file being packed into an s9pk.
#[derive(Clone)]
pub enum PackSource {
    /// In-memory bytes (e.g. the generated manifest).
    Buffered(Arc<[u8]>),
    /// A file on disk, opened on demand.
    File(PathBuf),
    /// A directory lazily converted to squashfs via [`SqfsDir`].
    Squashfs(Arc<SqfsDir>),
}
impl FileSource for PackSource {
type Reader = Box<dyn AsyncRead + Unpin + Send + Sync + 'static>;
async fn size(&self) -> Result<u64, Error> {
match self {
Self::Buffered(a) => Ok(a.len() as u64),
Self::File(f) => Ok(tokio::fs::metadata(f)
.await
.with_ctx(|_| (ErrorKind::Filesystem, f.display()))?
.len()),
Self::Squashfs(dir) => dir
.file()
.await
.with_ctx(|_| (ErrorKind::Filesystem, dir.path.display()))?
.size()
.await
.or_not_found("file metadata"),
}
}
async fn reader(&self) -> Result<Self::Reader, Error> {
match self {
Self::Buffered(a) => Ok(into_dyn_read(Cursor::new(a.clone()))),
Self::File(f) => Ok(into_dyn_read(
File::open(f)
.await
.with_ctx(|_| (ErrorKind::Filesystem, f.display()))?,
)),
Self::Squashfs(dir) => dir.file().await?.fetch_all().await.map(into_dyn_read),
}
}
}
// Allows a PackSource to be used anywhere a type-erased file source is needed.
impl From<PackSource> for DynFileSource {
    fn from(value: PackSource) -> Self {
        DynFileSource::new(value)
    }
}
/// CLI parameters for `s9pk pack`. Every path is optional; defaults are
/// derived from the project directory (see the accessors in `impl PackParams`).
#[derive(Deserialize, Serialize, Parser)]
pub struct PackParams {
    /// Project root (defaults to the current directory).
    pub path: Option<PathBuf>,
    /// Output s9pk path (defaults to `<path>/<id>.s9pk`).
    #[arg(short = 'o', long = "output")]
    pub output: Option<PathBuf>,
    /// Directory containing the package's javascript (defaults to `<path>/javascript`).
    #[arg(long = "javascript")]
    pub javascript: Option<PathBuf>,
    /// Icon file (defaults to the unique `icon.*` in the project root).
    #[arg(long = "icon")]
    pub icon: Option<PathBuf>,
    /// License file (defaults to `<path>/LICENSE.md`).
    #[arg(long = "license")]
    pub license: Option<PathBuf>,
    /// Instructions file (defaults to `<path>/instructions.md`).
    #[arg(long = "instructions")]
    pub instructions: Option<PathBuf>,
    /// Assets directory (defaults to `<path>/assets`).
    #[arg(long = "assets")]
    pub assets: Option<PathBuf>,
}
impl PackParams {
fn path(&self) -> &Path {
self.path.as_deref().unwrap_or(Path::new("."))
}
fn output(&self, id: &PackageId) -> PathBuf {
self.output
.as_ref()
.cloned()
.unwrap_or_else(|| self.path().join(id).with_extension("s9pk"))
}
fn javascript(&self) -> PathBuf {
self.javascript
.as_ref()
.cloned()
.unwrap_or_else(|| self.path().join("javascript"))
}
async fn icon(&self) -> Result<PathBuf, Error> {
if let Some(icon) = &self.icon {
Ok(icon.clone())
} else {
ReadDirStream::new(tokio::fs::read_dir(self.path()).await?).try_filter(|x| ready(x.path().file_stem() == Some(OsStr::new("icon")))).map_err(Error::from).try_fold(Err(Error::new(eyre!("icon not found"), ErrorKind::NotFound)), |acc, x| async move { match acc {
Ok(_) => Err(Error::new(eyre!("multiple icons found in working directory, please specify which to use with `--icon`"), ErrorKind::InvalidRequest)),
Err(e) => Ok({
let path = x.path();
if path.file_stem().and_then(|s| s.to_str()) == Some("icon") {
Ok(path)
} else {
Err(e)
}
})
}}).await?
}
}
fn license(&self) -> PathBuf {
self.license
.as_ref()
.cloned()
.unwrap_or_else(|| self.path().join("LICENSE.md"))
}
fn instructions(&self) -> PathBuf {
self.instructions
.as_ref()
.cloned()
.unwrap_or_else(|| self.path().join("instructions.md"))
}
fn assets(&self) -> PathBuf {
self.assets
.as_ref()
.cloned()
.unwrap_or_else(|| self.path().join("assets"))
}
}
/// How an image is provided and which architectures it supports.
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct ImageConfig {
    /// Where the image data comes from (packed, docker build, or docker tag).
    pub source: ImageSource,
    /// Architectures this image is provided for.
    #[ts(type = "string[]")]
    pub arch: BTreeSet<InternedString>,
    /// Arch whose artifacts are reused (under emulation) when the requested
    /// arch is missing; must be one of `arch`.
    #[ts(type = "string | null")]
    pub emulate_missing_as: Option<InternedString>,
}
// Default: image already packed, no arches declared, no emulation fallback.
impl Default for ImageConfig {
    fn default() -> Self {
        Self {
            source: ImageSource::Packed,
            arch: BTreeSet::new(),
            emulate_missing_as: None,
        }
    }
}
/// CLI mirror of [`ImageConfig`]; validated into it via `TryFrom`.
///
/// NOTE: clap's `conflicts_with`/`requires` take arg *IDs*, which for derive
/// default to the field name (`docker_tag`, `docker_build`) — not the
/// kebab-case `long` name. The previous kebab-case references named
/// nonexistent args, which clap rejects when the command is built.
#[derive(Parser)]
struct CliImageConfig {
    /// Build the image from a Dockerfile instead of using an existing tag.
    #[arg(long, conflicts_with("docker_tag"))]
    docker_build: bool,
    /// Dockerfile to build (defaults to `<workdir>/Dockerfile`).
    #[arg(long, requires("docker_build"))]
    dockerfile: Option<PathBuf>,
    /// Build context directory (defaults to `.`).
    #[arg(long, requires("docker_build"))]
    workdir: Option<PathBuf>,
    /// Existing docker image tag to pack.
    #[arg(long, conflicts_with_all(["dockerfile", "workdir"]))]
    docker_tag: Option<String>,
    /// Architectures the image supports (repeatable).
    #[arg(long)]
    arch: Vec<InternedString>,
    /// Arch to emulate missing arches as; must be one of `--arch`.
    #[arg(long)]
    emulate_missing_as: Option<InternedString>,
}
impl TryFrom<CliImageConfig> for ImageConfig {
type Error = clap::Error;
fn try_from(value: CliImageConfig) -> Result<Self, Self::Error> {
let res = Self {
source: if value.docker_build {
ImageSource::DockerBuild {
dockerfile: value.dockerfile,
workdir: value.workdir,
}
} else if let Some(tag) = value.docker_tag {
ImageSource::DockerTag(tag)
} else {
ImageSource::Packed
},
arch: value.arch.into_iter().collect(),
emulate_missing_as: value.emulate_missing_as,
};
res.emulate_missing_as
.as_ref()
.map(|a| {
if !res.arch.contains(a) {
Err(clap::Error::raw(
clap::error::ErrorKind::InvalidValue,
"`emulate-missing-as` must match one of the provided `arch`es",
))
} else {
Ok(())
}
})
.transpose()?;
Ok(res)
}
}
// Delegate argument definition to the CLI mirror type; parsing/validation
// happens in the FromArgMatches impl below.
impl clap::Args for ImageConfig {
    fn augment_args(cmd: clap::Command) -> clap::Command {
        CliImageConfig::augment_args(cmd)
    }
    fn augment_args_for_update(cmd: clap::Command) -> clap::Command {
        CliImageConfig::augment_args_for_update(cmd)
    }
}
impl clap::FromArgMatches for ImageConfig {
    /// Parses matches via [`CliImageConfig`], then validates into `ImageConfig`.
    fn from_arg_matches(matches: &clap::ArgMatches) -> Result<Self, clap::Error> {
        Self::try_from(CliImageConfig::from_arg_matches(matches)?)
    }
    fn update_from_arg_matches(&mut self, matches: &clap::ArgMatches) -> Result<(), clap::Error> {
        // update is a full replacement: re-parse the whole config
        *self = Self::try_from(CliImageConfig::from_arg_matches(matches)?)?;
        Ok(())
    }
}
/// Where an image's filesystem comes from when packing an s9pk.
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub enum ImageSource {
    /// Already packed into the s9pk; nothing to load.
    Packed,
    /// Built from a Dockerfile with the container tool.
    #[serde(rename_all = "camelCase")]
    DockerBuild {
        /// Build context directory (defaults to `.`).
        workdir: Option<PathBuf>,
        /// Dockerfile path (defaults to `<workdir>/Dockerfile`).
        dockerfile: Option<PathBuf>,
    },
    /// An existing image tag, pulled if not present locally.
    DockerTag(String),
}
impl ImageSource {
    /// Loads this image for `arch` into the archive directory `into`.
    ///
    /// - `Packed`: no-op, artifacts are already in the archive.
    /// - `DockerBuild`: builds a temporary tag with `CONTAINER_TOOL`, recurses
    ///   into the `DockerTag` path, then removes the tag.
    /// - `DockerTag`: inspects (pulling on demand) the image, then emits three
    ///   entries under `images/<arch>/<image_id>`: `.json` metadata, `.env`
    ///   environment variables, and a `.squashfs` of the exported container
    ///   filesystem (staged in `tmpdir`).
    #[instrument(skip_all)]
    pub fn load<'a, S: From<PackSource> + FileSource + Clone>(
        &'a self,
        tmpdir: &'a TmpDir,
        id: &'a PackageId,
        version: &'a VersionString,
        image_id: &'a ImageId,
        arch: &'a str,
        into: &'a mut DirectoryContents<S>,
    ) -> BoxFuture<'a, Result<(), Error>> {
        // Subset of `image inspect`'s `.Config` JSON that we care about.
        #[derive(Deserialize)]
        #[serde(rename_all = "PascalCase")]
        struct DockerImageConfig {
            env: Vec<String>,
            #[serde(default)]
            working_dir: PathBuf,
            #[serde(default)]
            user: String,
        }
        async move {
            match self {
                ImageSource::Packed => Ok(()),
                ImageSource::DockerBuild {
                    workdir,
                    dockerfile,
                } => {
                    let workdir = workdir.as_deref().unwrap_or(Path::new("."));
                    let dockerfile = dockerfile
                        .clone()
                        .unwrap_or_else(|| workdir.join("Dockerfile"));
                    // map uname-style arch names to docker platform names
                    let docker_platform = if arch == "x86_64" {
                        "--platform=linux/amd64".to_owned()
                    } else if arch == "aarch64" {
                        "--platform=linux/arm64".to_owned()
                    } else {
                        format!("--platform=linux/{arch}")
                    };
                    // docker buildx build ${path} -o type=image,name=start9/${id}
                    let tag = format!("start9/{id}/{image_id}:{version}");
                    Command::new(CONTAINER_TOOL)
                        .arg("build")
                        .arg(workdir)
                        .arg("-f")
                        .arg(dockerfile)
                        .arg("-t")
                        .arg(&tag)
                        .arg(&docker_platform)
                        .arg("-o")
                        .arg("type=image")
                        .capture(false)
                        .invoke(ErrorKind::Docker)
                        .await?;
                    // reuse the tag-based export path for the freshly built image
                    ImageSource::DockerTag(tag.clone())
                        .load(tmpdir, id, version, image_id, arch, into)
                        .await?;
                    // the temporary build tag is no longer needed
                    Command::new(CONTAINER_TOOL)
                        .arg("rmi")
                        .arg("-f")
                        .arg(&tag)
                        .invoke(ErrorKind::Docker)
                        .await?;
                    Ok(())
                }
                ImageSource::DockerTag(tag) => {
                    // map uname-style arch names to docker platform names
                    let docker_platform = if arch == "x86_64" {
                        "--platform=linux/amd64".to_owned()
                    } else if arch == "aarch64" {
                        "--platform=linux/arm64".to_owned()
                    } else {
                        format!("--platform=linux/{arch}")
                    };
                    let mut inspect_cmd = Command::new(CONTAINER_TOOL);
                    inspect_cmd
                        .arg("image")
                        .arg("inspect")
                        .arg("--format")
                        .arg("{{json .Config}}")
                        .arg(&tag);
                    // inspect first; if the image is absent locally (error text
                    // differs between docker and podman), pull and retry once
                    let inspect_res = match inspect_cmd.invoke(ErrorKind::Docker).await {
                        Ok(a) => a,
                        Err(e)
                            if {
                                let msg = e.source.to_string();
                                #[cfg(feature = "docker")]
                                let matches = msg.contains("No such image:");
                                #[cfg(not(feature = "docker"))]
                                let matches = msg.contains(": image not known");
                                matches
                            } =>
                        {
                            Command::new(CONTAINER_TOOL)
                                .arg("pull")
                                .arg(&docker_platform)
                                .arg(tag)
                                .capture(false)
                                .invoke(ErrorKind::Docker)
                                .await?;
                            inspect_cmd.invoke(ErrorKind::Docker).await?
                        }
                        Err(e) => return Err(e),
                    };
                    let config = serde_json::from_slice::<DockerImageConfig>(&inspect_res)
                        .with_kind(ErrorKind::Deserialization)?;
                    let base_path = Path::new("images").join(arch).join(image_id);
                    // <id>.json: workdir/user metadata, with docker's empty
                    // defaults normalized to "/" and "root"
                    into.insert_path(
                        base_path.with_extension("json"),
                        Entry::file(
                            PackSource::Buffered(
                                serde_json::to_vec(&ImageMetadata {
                                    workdir: if config.working_dir == Path::new("") {
                                        "/".into()
                                    } else {
                                        config.working_dir
                                    },
                                    user: if config.user.is_empty() {
                                        "root".into()
                                    } else {
                                        config.user.into()
                                    },
                                })
                                .with_kind(ErrorKind::Serialization)?
                                .into(),
                            )
                            .into(),
                        ),
                    )?;
                    // <id>.env: one VAR=value per line
                    into.insert_path(
                        base_path.with_extension("env"),
                        Entry::file(
                            PackSource::Buffered(config.env.join("\n").into_bytes().into()).into(),
                        ),
                    )?;
                    // <id>.squashfs: create a container from the image, stream
                    // its exported rootfs tar through mksquashfs, then discard
                    // the container
                    let dest = tmpdir.join(Guid::new().as_ref()).with_extension("squashfs");
                    let container = String::from_utf8(
                        Command::new(CONTAINER_TOOL)
                            .arg("create")
                            .arg(&docker_platform)
                            .arg(&tag)
                            .invoke(ErrorKind::Docker)
                            .await?,
                    )?;
                    Command::new(CONTAINER_TOOL)
                        .arg("export")
                        .arg(container.trim())
                        .pipe(Command::new("mksquashfs").arg("-").arg(&dest).arg("-tar"))
                        .capture(false)
                        .invoke(ErrorKind::Docker)
                        .await?;
                    Command::new(CONTAINER_TOOL)
                        .arg("rm")
                        .arg(container.trim())
                        .invoke(ErrorKind::Docker)
                        .await?;
                    into.insert_path(
                        base_path.with_extension("squashfs"),
                        Entry::file(PackSource::File(dest).into()),
                    )?;
                    Ok(())
                }
            }
        }
        .boxed()
    }
}
/// Metadata extracted from a container image's config, stored as
/// `images/<arch>/<id>.json` in the s9pk.
#[derive(Debug, Clone, Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct ImageMetadata {
    /// Working directory for the container (normalized to "/" when the image
    /// leaves it unset).
    pub workdir: PathBuf,
    /// User to run as (normalized to "root" when the image leaves it unset).
    #[ts(type = "string")]
    pub user: InternedString,
}
/// `s9pk pack`: assembles an s9pk from a project directory.
///
/// Gathers the manifest (evaluated from the project's javascript), icon,
/// license, instructions, javascript squashfs, declared assets, and images,
/// then validates, filters, signs, and serializes the archive to the output
/// path.
#[instrument(skip_all)]
pub async fn pack(ctx: CliContext, params: PackParams) -> Result<(), Error> {
    let tmpdir = Arc::new(TmpDir::new().await?);
    let mut files = DirectoryContents::<PackSource>::new();
    let js_dir = params.javascript();
    // evaluate the package's index.js with node to extract the manifest as JSON
    let manifest: Arc<[u8]> = Command::new("node")
        .arg("-e")
        .arg(format!(
            "console.log(JSON.stringify(require('{}/index.js').manifest))",
            js_dir.display()
        ))
        .invoke(ErrorKind::Javascript)
        .await?
        .into();
    files.insert(
        "manifest.json".into(),
        Entry::file(PackSource::Buffered(manifest.clone())),
    );
    // icon keeps its original extension: icon.<ext>
    let icon = params.icon().await?;
    let icon_ext = icon
        .extension()
        .or_not_found("icon file extension")?
        .to_string_lossy();
    files.insert(
        InternedString::from_display(&lazy_format!("icon.{}", icon_ext)),
        Entry::file(PackSource::File(icon)),
    );
    files.insert(
        "LICENSE.md".into(),
        Entry::file(PackSource::File(params.license())),
    );
    files.insert(
        "instructions.md".into(),
        Entry::file(PackSource::File(params.instructions())),
    );
    // javascript directory is packed lazily into a squashfs
    files.insert(
        "javascript.squashfs".into(),
        Entry::file(PackSource::Squashfs(Arc::new(SqfsDir::new(
            js_dir,
            tmpdir.clone(),
        )))),
    );
    // sign with the developer key from the CLI context
    let mut s9pk = S9pk::new(
        MerkleArchive::new(files, ctx.developer_key()?.clone(), SIG_CONTEXT),
        None,
    )
    .await?;
    // each asset declared in the manifest becomes assets/<name>.squashfs
    let assets_dir = params.assets();
    for assets in s9pk.as_manifest().assets.clone() {
        s9pk.as_archive_mut().contents_mut().insert_path(
            Path::new("assets").join(&assets).with_extension("squashfs"),
            Entry::file(PackSource::Squashfs(Arc::new(SqfsDir::new(
                assets_dir.join(&assets),
                tmpdir.clone(),
            )))),
        )?;
    }
    // materialize docker-sourced images into the archive
    s9pk.load_images(&*tmpdir).await?;
    // None: validate for all declared arches
    s9pk.validate_and_filter(None)?;
    s9pk.serialize(
        &mut File::create(params.output(&s9pk.as_manifest().id)).await?,
        false,
    )
    .await?;
    // drop the archive before collecting the tmpdir it references
    drop(s9pk);
    tmpdir.gc().await?;
    Ok(())
}

View File

@@ -4,6 +4,7 @@ use models::{ActionId, ProcedureName};
use crate::action::ActionResult;
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::service::config::GetConfig;
use crate::service::dependencies::DependencyConfig;
use crate::service::{Service, ServiceActor};
@@ -23,13 +24,18 @@ impl Handler<Action> for ServiceActor {
}
async fn handle(
&mut self,
Action { id, input }: Action,
id: Guid,
Action {
id: action_id,
input,
}: Action,
_: &BackgroundJobQueue,
) -> Self::Response {
let container = &self.0.persistent_container;
container
.execute::<ActionResult>(
ProcedureName::RunAction(id),
id,
ProcedureName::RunAction(action_id),
input,
Some(Duration::from_secs(30)),
)
@@ -39,7 +45,20 @@ impl Handler<Action> for ServiceActor {
}
impl Service {
pub async fn action(&self, id: ActionId, input: Value) -> Result<ActionResult, Error> {
self.actor.send(Action { id, input }).await?
pub async fn action(
&self,
id: Guid,
action_id: ActionId,
input: Value,
) -> Result<ActionResult, Error> {
self.actor
.send(
id,
Action {
id: action_id,
input,
},
)
.await?
}
}

View File

@@ -5,6 +5,7 @@ use models::ProcedureName;
use crate::config::action::ConfigRes;
use crate::config::ConfigureContext;
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::service::dependencies::DependencyConfig;
use crate::service::{Service, ServiceActor};
use crate::util::actor::background::BackgroundJobQueue;
@@ -19,6 +20,7 @@ impl Handler<Configure> for ServiceActor {
}
async fn handle(
&mut self,
id: Guid,
Configure(ConfigureContext { timeout, config }): Configure,
_: &BackgroundJobQueue,
) -> Self::Response {
@@ -26,7 +28,7 @@ impl Handler<Configure> for ServiceActor {
let package_id = &self.0.id;
container
.execute::<NoOutput>(ProcedureName::SetConfig, to_value(&config)?, timeout)
.execute::<NoOutput>(id, ProcedureName::SetConfig, to_value(&config)?, timeout)
.await
.with_kind(ErrorKind::ConfigRulesViolation)?;
self.0
@@ -52,10 +54,11 @@ impl Handler<GetConfig> for ServiceActor {
fn conflicts_with(_: &GetConfig) -> ConflictBuilder<Self> {
ConflictBuilder::nothing().except::<Configure>()
}
async fn handle(&mut self, _: GetConfig, _: &BackgroundJobQueue) -> Self::Response {
async fn handle(&mut self, id: Guid, _: GetConfig, _: &BackgroundJobQueue) -> Self::Response {
let container = &self.0.persistent_container;
container
.execute::<ConfigRes>(
id,
ProcedureName::GetConfig,
Value::Null,
Some(Duration::from_secs(30)), // TODO timeout
@@ -66,10 +69,10 @@ impl Handler<GetConfig> for ServiceActor {
}
impl Service {
pub async fn configure(&self, ctx: ConfigureContext) -> Result<(), Error> {
self.actor.send(Configure(ctx)).await?
pub async fn configure(&self, id: Guid, ctx: ConfigureContext) -> Result<(), Error> {
self.actor.send(id, Configure(ctx)).await?
}
pub async fn get_config(&self) -> Result<ConfigRes, Error> {
self.actor.send(GetConfig).await?
pub async fn get_config(&self, id: Guid) -> Result<ConfigRes, Error> {
self.actor.send(id, GetConfig).await?
}
}

View File

@@ -1,4 +1,5 @@
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::service::config::GetConfig;
use crate::service::dependencies::DependencyConfig;
use crate::service::start_stop::StartStop;
@@ -15,7 +16,7 @@ impl Handler<Start> for ServiceActor {
.except::<GetConfig>()
.except::<DependencyConfig>()
}
async fn handle(&mut self, _: Start, _: &BackgroundJobQueue) -> Self::Response {
async fn handle(&mut self, _: Guid, _: Start, _: &BackgroundJobQueue) -> Self::Response {
self.0.persistent_container.state.send_modify(|x| {
x.desired_state = StartStop::Start;
});
@@ -23,8 +24,8 @@ impl Handler<Start> for ServiceActor {
}
}
impl Service {
pub async fn start(&self) -> Result<(), Error> {
self.actor.send(Start).await
pub async fn start(&self, id: Guid) -> Result<(), Error> {
self.actor.send(id, Start).await
}
}
@@ -36,7 +37,7 @@ impl Handler<Stop> for ServiceActor {
.except::<GetConfig>()
.except::<DependencyConfig>()
}
async fn handle(&mut self, _: Stop, _: &BackgroundJobQueue) -> Self::Response {
async fn handle(&mut self, _: Guid, _: Stop, _: &BackgroundJobQueue) -> Self::Response {
let mut transition_state = None;
self.0.persistent_container.state.send_modify(|x| {
x.desired_state = StartStop::Stop;
@@ -51,7 +52,7 @@ impl Handler<Stop> for ServiceActor {
}
}
impl Service {
pub async fn stop(&self) -> Result<(), Error> {
self.actor.send(Stop).await
pub async fn stop(&self, id: Guid) -> Result<(), Error> {
self.actor.send(id, Stop).await
}
}

View File

@@ -4,35 +4,28 @@ use imbl_value::json;
use models::{PackageId, ProcedureName};
use crate::prelude::*;
use crate::service::{Service, ServiceActor};
use crate::rpc_continuations::Guid;
use crate::service::{Service, ServiceActor, ServiceActorSeed};
use crate::util::actor::background::BackgroundJobQueue;
use crate::util::actor::{ConflictBuilder, Handler};
use crate::Config;
pub(super) struct DependencyConfig {
dependency_id: PackageId,
remote_config: Option<Config>,
}
impl Handler<DependencyConfig> for ServiceActor {
type Response = Result<Option<Config>, Error>;
fn conflicts_with(_: &DependencyConfig) -> ConflictBuilder<Self> {
ConflictBuilder::nothing()
}
async fn handle(
&mut self,
DependencyConfig {
dependency_id,
remote_config,
}: DependencyConfig,
_: &BackgroundJobQueue,
) -> Self::Response {
let container = &self.0.persistent_container;
impl ServiceActorSeed {
async fn dependency_config(
&self,
id: Guid,
dependency_id: PackageId,
remote_config: Option<Config>,
) -> Result<Option<Config>, Error> {
let container = &self.persistent_container;
container
.sanboxed::<Option<Config>>(
id.clone(),
ProcedureName::UpdateDependency(dependency_id.clone()),
json!({
"queryResults": container
.execute::<Value>(
id,
ProcedureName::QueryDependency(dependency_id),
Value::Null,
Some(Duration::from_secs(30)),
@@ -49,17 +42,45 @@ impl Handler<DependencyConfig> for ServiceActor {
}
}
pub(super) struct DependencyConfig {
dependency_id: PackageId,
remote_config: Option<Config>,
}
impl Handler<DependencyConfig> for ServiceActor {
type Response = Result<Option<Config>, Error>;
fn conflicts_with(_: &DependencyConfig) -> ConflictBuilder<Self> {
ConflictBuilder::nothing()
}
async fn handle(
&mut self,
id: Guid,
DependencyConfig {
dependency_id,
remote_config,
}: DependencyConfig,
_: &BackgroundJobQueue,
) -> Self::Response {
self.0
.dependency_config(id, dependency_id, remote_config)
.await
}
}
impl Service {
pub async fn dependency_config(
&self,
id: Guid,
dependency_id: PackageId,
remote_config: Option<Config>,
) -> Result<Option<Config>, Error> {
self.actor
.send(DependencyConfig {
dependency_id,
remote_config,
})
.send(
id,
DependencyConfig {
dependency_id,
remote_config,
},
)
.await?
}
}

View File

@@ -1,4 +1,5 @@
use std::sync::Arc;
use std::ops::Deref;
use std::sync::{Arc, Weak};
use std::time::Duration;
use chrono::{DateTime, Utc};
@@ -10,7 +11,8 @@ use persistent_container::PersistentContainer;
use rpc_toolkit::{from_fn_async, CallRemoteHandler, Empty, HandlerArgs, HandlerFor};
use serde::{Deserialize, Serialize};
use start_stop::StartStop;
use tokio::{fs::File, sync::Notify};
use tokio::fs::File;
use tokio::sync::Notify;
use ts_rs::TS;
use crate::context::{CliContext, RpcContext};
@@ -67,13 +69,87 @@ pub enum LoadDisposition {
Undo,
}
/// Shared, reference-counted handle to a [`Service`]; owns teardown
/// (`shutdown`/`uninstall`), which requires exclusive ownership of the `Arc`.
pub struct ServiceRef(Arc<Service>);
impl ServiceRef {
    /// Non-owning handle for callers that must not keep the service alive.
    pub fn weak(&self) -> Weak<Service> {
        Arc::downgrade(&self.0)
    }
    /// Runs the package's `Uninit` procedure, shuts the service down, and —
    /// when this is a true uninstall (`target_version` is `None`, i.e. not a
    /// downgrade/update) — removes the package's entry from the db.
    pub async fn uninstall(
        self,
        target_version: Option<models::VersionString>,
    ) -> Result<(), Error> {
        self.seed
            .persistent_container
            .execute(
                Guid::new(),
                ProcedureName::Uninit,
                to_value(&target_version)?,
                None,
            ) // TODO timeout
            .await?;
        let id = self.seed.persistent_container.s9pk.as_manifest().id.clone();
        let ctx = self.seed.ctx.clone();
        self.shutdown().await?;
        if target_version.is_none() {
            ctx.db
                .mutate(|d| d.as_public_mut().as_package_data_mut().remove(&id))
                .await?;
        }
        Ok(())
    }
    /// Tears the service down in order: stop the container's RPC server,
    /// drain the actor, then exit the persistent container.
    ///
    /// Requires this to be the last strong reference — both `Arc::try_unwrap`
    /// calls error out if the `Service` or its seed is still held elsewhere.
    pub async fn shutdown(self) -> Result<(), Error> {
        // take the rpc server out of the container (if still running) and ask
        // the in-container process to exit before shutting the server down
        if let Some((hdl, shutdown)) = self.seed.persistent_container.rpc_server.send_replace(None)
        {
            self.seed
                .persistent_container
                .rpc_client
                .request(rpc::Exit, Empty {})
                .await?;
            shutdown.shutdown();
            hdl.await.with_kind(ErrorKind::Cancelled)?;
        }
        let service = Arc::try_unwrap(self.0).map_err(|_| {
            Error::new(
                eyre!("ServiceActor held somewhere after actor shutdown"),
                ErrorKind::Unknown,
            )
        })?;
        // let in-flight actor messages finish before dismantling the seed
        service
            .actor
            .shutdown(crate::util::actor::PendingMessageStrategy::FinishAll { timeout: None }) // TODO timeout
            .await;
        Arc::try_unwrap(service.seed)
            .map_err(|_| {
                Error::new(
                    eyre!("ServiceActorSeed held somewhere after actor shutdown"),
                    ErrorKind::Unknown,
                )
            })?
            .persistent_container
            .exit()
            .await?;
        Ok(())
    }
}
// Transparent access to the underlying Service (e.g. `self.seed` above).
impl Deref for ServiceRef {
    type Target = Service;
    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}
// Wraps a freshly constructed Service in the shared handle.
impl From<Service> for ServiceRef {
    fn from(value: Service) -> Self {
        Self(Arc::new(value))
    }
}
/// A running service: the actor serializing operations against it, plus the
/// shared seed state the actor operates on.
pub struct Service {
    actor: ConcurrentActor<ServiceActor>,
    seed: Arc<ServiceActorSeed>,
}
impl Service {
#[instrument(skip_all)]
async fn new(ctx: RpcContext, s9pk: S9pk, start: StartStop) -> Result<Self, Error> {
async fn new(ctx: RpcContext, s9pk: S9pk, start: StartStop) -> Result<ServiceRef, Error> {
let id = s9pk.as_manifest().id.clone();
let persistent_container = PersistentContainer::new(
&ctx, s9pk,
@@ -88,13 +164,17 @@ impl Service {
ctx,
synchronized: Arc::new(Notify::new()),
});
seed.persistent_container
.init(Arc::downgrade(&seed))
.await?;
Ok(Self {
let service: ServiceRef = Self {
actor: ConcurrentActor::new(ServiceActor(seed.clone())),
seed,
})
}
.into();
service
.seed
.persistent_container
.init(service.weak())
.await?;
Ok(service)
}
#[instrument(skip_all)]
@@ -102,7 +182,7 @@ impl Service {
ctx: &RpcContext,
id: &PackageId,
disposition: LoadDisposition,
) -> Result<Option<Self>, Error> {
) -> Result<Option<ServiceRef>, Error> {
let handle_installed = {
let ctx = ctx.clone();
move |s9pk: S9pk, i: Model<PackageDataEntry>| async move {
@@ -136,7 +216,7 @@ impl Service {
match entry.as_state_info().as_match() {
PackageStateMatchModelRef::Installing(_) => {
if disposition == LoadDisposition::Retry {
if let Ok(s9pk) = S9pk::open(s9pk_path, Some(id), true).await.map_err(|e| {
if let Ok(s9pk) = S9pk::open(s9pk_path, Some(id)).await.map_err(|e| {
tracing::error!("Error opening s9pk for install: {e}");
tracing::debug!("{e:?}")
}) {
@@ -169,7 +249,7 @@ impl Service {
&& progress == &Progress::Complete(true)
})
{
if let Ok(s9pk) = S9pk::open(&s9pk_path, Some(id), true).await.map_err(|e| {
if let Ok(s9pk) = S9pk::open(&s9pk_path, Some(id)).await.map_err(|e| {
tracing::error!("Error opening s9pk for update: {e}");
tracing::debug!("{e:?}")
}) {
@@ -188,7 +268,7 @@ impl Service {
}
}
}
let s9pk = S9pk::open(s9pk_path, Some(id), true).await?;
let s9pk = S9pk::open(s9pk_path, Some(id)).await?;
ctx.db
.mutate({
|db| {
@@ -213,7 +293,7 @@ impl Service {
handle_installed(s9pk, entry).await
}
PackageStateMatchModelRef::Removing(_) | PackageStateMatchModelRef::Restoring(_) => {
if let Ok(s9pk) = S9pk::open(s9pk_path, Some(id), true).await.map_err(|e| {
if let Ok(s9pk) = S9pk::open(s9pk_path, Some(id)).await.map_err(|e| {
tracing::error!("Error opening s9pk for removal: {e}");
tracing::debug!("{e:?}")
}) {
@@ -224,7 +304,7 @@ impl Service {
tracing::debug!("{e:?}")
})
{
match service.uninstall(None).await {
match ServiceRef::from(service).uninstall(None).await {
Err(e) => {
tracing::error!("Error uninstalling service: {e}");
tracing::debug!("{e:?}")
@@ -241,7 +321,7 @@ impl Service {
Ok(None)
}
PackageStateMatchModelRef::Installed(_) => {
handle_installed(S9pk::open(s9pk_path, Some(id), true).await?, entry).await
handle_installed(S9pk::open(s9pk_path, Some(id)).await?, entry).await
}
PackageStateMatchModelRef::Error(e) => Err(Error::new(
eyre!("Failed to parse PackageDataEntry, found {e:?}"),
@@ -256,7 +336,7 @@ impl Service {
s9pk: S9pk,
src_version: Option<models::VersionString>,
progress: Option<InstallProgressHandles>,
) -> Result<Self, Error> {
) -> Result<ServiceRef, Error> {
let manifest = s9pk.as_manifest().clone();
let developer_key = s9pk.as_archive().signer();
let icon = s9pk.icon_data_url().await?;
@@ -264,12 +344,17 @@ impl Service {
service
.seed
.persistent_container
.execute(ProcedureName::Init, to_value(&src_version)?, None) // TODO timeout
.execute(
Guid::new(),
ProcedureName::Init,
to_value(&src_version)?,
None,
) // TODO timeout
.await
.with_kind(ErrorKind::MigrationFailed)?; // TODO: handle cancellation
if let Some(mut progress) = progress {
progress.finalization_progress.complete();
progress.progress_handle.complete();
progress.progress.complete();
tokio::task::yield_now().await;
}
ctx.db
@@ -300,61 +385,21 @@ impl Service {
s9pk: S9pk,
backup_source: impl GenericMountGuard,
progress: Option<InstallProgressHandles>,
) -> Result<Self, Error> {
) -> Result<ServiceRef, Error> {
let service = Service::install(ctx.clone(), s9pk, None, progress).await?;
service
.actor
.send(transition::restore::Restore {
path: backup_source.path().to_path_buf(),
})
.await?;
.send(
Guid::new(),
transition::restore::Restore {
path: backup_source.path().to_path_buf(),
},
)
.await??;
Ok(service)
}
pub async fn shutdown(self) -> Result<(), Error> {
self.actor
.shutdown(crate::util::actor::PendingMessageStrategy::FinishAll { timeout: None }) // TODO timeout
.await;
if let Some((hdl, shutdown)) = self.seed.persistent_container.rpc_server.send_replace(None)
{
self.seed
.persistent_container
.rpc_client
.request(rpc::Exit, Empty {})
.await?;
shutdown.shutdown();
hdl.await.with_kind(ErrorKind::Cancelled)?;
}
Arc::try_unwrap(self.seed)
.map_err(|_| {
Error::new(
eyre!("ServiceActorSeed held somewhere after actor shutdown"),
ErrorKind::Unknown,
)
})?
.persistent_container
.exit()
.await?;
Ok(())
}
pub async fn uninstall(self, target_version: Option<models::VersionString>) -> Result<(), Error> {
self.seed
.persistent_container
.execute(ProcedureName::Uninit, to_value(&target_version)?, None) // TODO timeout
.await?;
let id = self.seed.persistent_container.s9pk.as_manifest().id.clone();
let ctx = self.seed.ctx.clone();
self.shutdown().await?;
if target_version.is_none() {
ctx.db
.mutate(|d| d.as_public_mut().as_package_data_mut().remove(&id))
.await?;
}
Ok(())
}
#[instrument(skip_all)]
pub async fn backup(&self, guard: impl GenericMountGuard) -> Result<(), Error> {
let id = &self.seed.id;
@@ -367,10 +412,13 @@ impl Service {
.await?;
drop(file);
self.actor
.send(transition::backup::Backup {
path: guard.path().to_path_buf(),
})
.await?;
.send(
Guid::new(),
transition::backup::Backup {
path: guard.path().to_path_buf(),
},
)
.await??;
Ok(())
}
@@ -437,7 +485,7 @@ impl Actor for ServiceActor {
let mut current = seed.persistent_container.state.subscribe();
loop {
let kinds = dbg!(current.borrow().kinds());
let kinds = current.borrow().kinds();
if let Err(e) = async {
let main_status = match (
@@ -445,6 +493,14 @@ impl Actor for ServiceActor {
kinds.desired_state,
kinds.running_status,
) {
(Some(TransitionKind::Restarting), StartStop::Stop, Some(_)) => {
seed.persistent_container.stop().await?;
MainStatus::Restarting
}
(Some(TransitionKind::Restarting), StartStop::Start, _) => {
seed.persistent_container.start().await?;
MainStatus::Restarting
}
(Some(TransitionKind::Restarting), _, _) => MainStatus::Restarting,
(Some(TransitionKind::Restoring), _, _) => MainStatus::Restoring,
(Some(TransitionKind::BackingUp), _, Some(status)) => {
@@ -475,6 +531,30 @@ impl Actor for ServiceActor {
.mutate(|d| {
if let Some(i) = d.as_public_mut().as_package_data_mut().as_idx_mut(&id)
{
let previous = i.as_status().as_main().de()?;
let previous_health = previous.health();
let previous_started = previous.started();
let mut main_status = main_status;
match &mut main_status {
&mut MainStatus::Running { ref mut health, .. }
| &mut MainStatus::BackingUp { ref mut health, .. } => {
*health = previous_health.unwrap_or(health).clone();
}
_ => (),
};
match &mut main_status {
MainStatus::Running {
ref mut started, ..
} => {
*started = previous_started.unwrap_or(*started);
}
MainStatus::BackingUp {
ref mut started, ..
} => {
*started = previous_started.map(Some).unwrap_or(*started);
}
_ => (),
};
i.as_status_mut().as_main_mut().ser(&main_status)?;
}
Ok(())

View File

@@ -6,11 +6,10 @@ use std::time::Duration;
use futures::future::ready;
use futures::{Future, FutureExt};
use helpers::NonDetachingJoinHandle;
use imbl_value::InternedString;
use models::{ProcedureName, VolumeId};
use models::{ImageId, ProcedureName, VolumeId};
use rpc_toolkit::{Empty, Server, ShutdownHandle};
use serde::de::DeserializeOwned;
use tokio::fs::{ File};
use tokio::fs::File;
use tokio::process::Command;
use tokio::sync::{oneshot, watch, Mutex, OnceCell};
use tracing::instrument;
@@ -24,14 +23,15 @@ use crate::disk::mount::filesystem::idmapped::IdMapped;
use crate::disk::mount::filesystem::loop_dev::LoopDev;
use crate::disk::mount::filesystem::overlayfs::OverlayGuard;
use crate::disk::mount::filesystem::{MountType, ReadOnly};
use crate::disk::mount::guard::MountGuard;
use crate::disk::mount::guard::{GenericMountGuard, MountGuard};
use crate::lxc::{LxcConfig, LxcContainer, HOST_RPC_SERVER_SOCKET};
use crate::net::net_controller::NetService;
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::s9pk::merkle_archive::source::FileSource;
use crate::s9pk::S9pk;
use crate::service::start_stop::StartStop;
use crate::service::{rpc, RunningStatus};
use crate::service::{rpc, RunningStatus, Service};
use crate::util::rpc_client::UnixRpcClient;
use crate::util::Invoke;
use crate::volume::{asset_dir, data_dir};
@@ -89,7 +89,8 @@ pub struct PersistentContainer {
js_mount: MountGuard,
volumes: BTreeMap<VolumeId, MountGuard>,
assets: BTreeMap<VolumeId, MountGuard>,
pub(super) overlays: Arc<Mutex<BTreeMap<InternedString, OverlayGuard>>>,
pub(super) images: BTreeMap<ImageId, Arc<MountGuard>>,
pub(super) overlays: Arc<Mutex<BTreeMap<Guid, OverlayGuard<Arc<MountGuard>>>>>,
pub(super) state: Arc<watch::Sender<ServiceState>>,
pub(super) net_service: Mutex<NetService>,
destroyed: bool,
@@ -178,14 +179,62 @@ impl PersistentContainer {
.await?,
);
}
let mut images = BTreeMap::new();
let image_path = lxc_container.rootfs_dir().join("media/startos/images");
tokio::fs::create_dir_all(&image_path).await?;
for image in &s9pk.as_manifest().images {
for (image, config) in &s9pk.as_manifest().images {
let mut arch = ARCH;
let mut sqfs_path = Path::new("images")
.join(arch)
.join(image)
.with_extension("squashfs");
if !s9pk
.as_archive()
.contents()
.get_path(&sqfs_path)
.and_then(|e| e.as_file())
.is_some()
{
arch = if let Some(arch) = config.emulate_missing_as.as_deref() {
arch
} else {
continue;
};
sqfs_path = Path::new("images")
.join(arch)
.join(image)
.with_extension("squashfs");
}
let sqfs = s9pk
.as_archive()
.contents()
.get_path(&sqfs_path)
.and_then(|e| e.as_file())
.or_not_found(sqfs_path.display())?;
let mountpoint = image_path.join(image);
tokio::fs::create_dir_all(&mountpoint).await?;
Command::new("chown")
.arg("100000:100000")
.arg(&mountpoint)
.invoke(ErrorKind::Filesystem)
.await?;
images.insert(
image.clone(),
Arc::new(
MountGuard::mount(
&IdMapped::new(LoopDev::from(&**sqfs), 0, 100000, 65536),
&mountpoint,
ReadOnly,
)
.await?,
),
);
let env_filename = Path::new(image.as_ref()).with_extension("env");
if let Some(env) = s9pk
.as_archive()
.contents()
.get_path(Path::new("images").join(*ARCH).join(&env_filename))
.get_path(Path::new("images").join(arch).join(&env_filename))
.and_then(|e| e.as_file())
{
env.copy(&mut File::create(image_path.join(&env_filename)).await?)
@@ -195,7 +244,7 @@ impl PersistentContainer {
if let Some(json) = s9pk
.as_archive()
.contents()
.get_path(Path::new("images").join(*ARCH).join(&json_filename))
.get_path(Path::new("images").join(arch).join(&json_filename))
.and_then(|e| e.as_file())
{
json.copy(&mut File::create(image_path.join(&json_filename)).await?)
@@ -215,6 +264,7 @@ impl PersistentContainer {
js_mount,
volumes,
assets,
images,
overlays: Arc::new(Mutex::new(BTreeMap::new())),
state: Arc::new(watch::channel(ServiceState::new(start)).0),
net_service: Mutex::new(net_service),
@@ -257,7 +307,7 @@ impl PersistentContainer {
}
#[instrument(skip_all)]
pub async fn init(&self, seed: Weak<ServiceActorSeed>) -> Result<(), Error> {
pub async fn init(&self, seed: Weak<Service>) -> Result<(), Error> {
let socket_server_context = EffectContext::new(seed);
let server = Server::new(
move || ready(Ok(socket_server_context.clone())),
@@ -330,6 +380,7 @@ impl PersistentContainer {
let js_mount = self.js_mount.take();
let volumes = std::mem::take(&mut self.volumes);
let assets = std::mem::take(&mut self.assets);
let images = std::mem::take(&mut self.images);
let overlays = self.overlays.clone();
let lxc_container = self.lxc_container.take();
self.destroyed = true;
@@ -352,6 +403,9 @@ impl PersistentContainer {
for (_, overlay) in std::mem::take(&mut *overlays.lock().await) {
errs.handle(overlay.unmount(true).await);
}
for (_, images) in images {
errs.handle(images.unmount().await);
}
errs.handle(js_mount.unmount(true).await);
if let Some(lxc_container) = lxc_container {
errs.handle(lxc_container.exit().await);
@@ -378,6 +432,7 @@ impl PersistentContainer {
#[instrument(skip_all)]
pub async fn start(&self) -> Result<(), Error> {
self.execute(
Guid::new(),
ProcedureName::StartMain,
Value::Null,
Some(Duration::from_secs(5)), // TODO
@@ -389,7 +444,7 @@ impl PersistentContainer {
#[instrument(skip_all)]
pub async fn stop(&self) -> Result<Duration, Error> {
let timeout: Option<crate::util::serde::Duration> = self
.execute(ProcedureName::StopMain, Value::Null, None)
.execute(Guid::new(), ProcedureName::StopMain, Value::Null, None)
.await?;
Ok(timeout.map(|a| *a).unwrap_or(Duration::from_secs(30)))
}
@@ -397,6 +452,7 @@ impl PersistentContainer {
#[instrument(skip_all)]
pub async fn execute<O>(
&self,
id: Guid,
name: ProcedureName,
input: Value,
timeout: Option<Duration>,
@@ -404,7 +460,7 @@ impl PersistentContainer {
where
O: DeserializeOwned,
{
self._execute(name, input, timeout)
self._execute(id, name, input, timeout)
.await
.and_then(from_value)
}
@@ -412,6 +468,7 @@ impl PersistentContainer {
#[instrument(skip_all)]
pub async fn sanboxed<O>(
&self,
id: Guid,
name: ProcedureName,
input: Value,
timeout: Option<Duration>,
@@ -419,7 +476,7 @@ impl PersistentContainer {
where
O: DeserializeOwned,
{
self._sandboxed(name, input, timeout)
self._sandboxed(id, name, input, timeout)
.await
.and_then(from_value)
}
@@ -427,13 +484,15 @@ impl PersistentContainer {
#[instrument(skip_all)]
async fn _execute(
&self,
id: Guid,
name: ProcedureName,
input: Value,
timeout: Option<Duration>,
) -> Result<Value, Error> {
let fut = self
.rpc_client
.request(rpc::Execute, rpc::ExecuteParams::new(name, input, timeout));
let fut = self.rpc_client.request(
rpc::Execute,
rpc::ExecuteParams::new(id, name, input, timeout),
);
Ok(if let Some(timeout) = timeout {
tokio::time::timeout(timeout, fut)
@@ -447,13 +506,15 @@ impl PersistentContainer {
#[instrument(skip_all)]
async fn _sandboxed(
&self,
id: Guid,
name: ProcedureName,
input: Value,
timeout: Option<Duration>,
) -> Result<Value, Error> {
let fut = self
.rpc_client
.request(rpc::Sandbox, rpc::ExecuteParams::new(name, input, timeout));
let fut = self.rpc_client.request(
rpc::Sandbox,
rpc::ExecuteParams::new(id, name, input, timeout),
);
Ok(if let Some(timeout) = timeout {
tokio::time::timeout(timeout, fut)

View File

@@ -3,6 +3,7 @@ use std::time::Duration;
use models::ProcedureName;
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::service::Service;
impl Service {
@@ -11,6 +12,7 @@ impl Service {
let container = &self.seed.persistent_container;
container
.execute::<Value>(
Guid::new(),
ProcedureName::Properties,
Value::Null,
Some(Duration::from_secs(30)),

View File

@@ -7,6 +7,7 @@ use rpc_toolkit::Empty;
use ts_rs::TS;
use crate::prelude::*;
use crate::rpc_continuations::Guid;
#[derive(Clone)]
pub struct Init;
@@ -46,14 +47,21 @@ impl serde::Serialize for Exit {
#[derive(Clone, serde::Deserialize, serde::Serialize, TS)]
pub struct ExecuteParams {
id: Guid,
procedure: String,
#[ts(type = "any")]
input: Value,
timeout: Option<u128>,
}
impl ExecuteParams {
pub fn new(procedure: ProcedureName, input: Value, timeout: Option<Duration>) -> Self {
pub fn new(
id: Guid,
procedure: ProcedureName,
input: Value,
timeout: Option<Duration>,
) -> Self {
Self {
id,
procedure: procedure.js_function_name(),
input,
timeout: timeout.map(|d| d.as_millis()),

View File

@@ -7,10 +7,9 @@ use std::str::FromStr;
use std::sync::{Arc, Weak};
use clap::builder::ValueParserFactory;
use clap::Parser;
use clap::{CommandFactory, FromArgMatches, Parser};
use emver::VersionRange;
use imbl::OrdMap;
use imbl_value::{json, InternedString};
use imbl_value::json;
use itertools::Itertools;
use models::{
ActionId, DataUrl, HealthCheckId, HostId, ImageId, PackageId, ServiceInterfaceId, VolumeId,
@@ -26,38 +25,34 @@ use crate::db::model::package::{
ActionMetadata, CurrentDependencies, CurrentDependencyInfo, CurrentDependencyKind,
ManifestPreference,
};
use crate::disk::mount::filesystem::idmapped::IdMapped;
use crate::disk::mount::filesystem::loop_dev::LoopDev;
use crate::disk::mount::filesystem::overlayfs::OverlayGuard;
use crate::echo;
use crate::net::host::address::HostAddress;
use crate::net::host::binding::BindOptions;
use crate::net::host::{self, HostKind};
use crate::net::service_interface::{
AddressInfo, ExportedHostInfo, ExportedHostnameInfo, ServiceInterface, ServiceInterfaceType,
ServiceInterfaceWithHostInfo,
};
use crate::net::host::binding::{BindOptions, LanInfo};
use crate::net::host::{Host, HostKind};
use crate::net::service_interface::{AddressInfo, ServiceInterface, ServiceInterfaceType};
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::s9pk::merkle_archive::source::http::HttpSource;
use crate::s9pk::rpc::SKIP_ENV;
use crate::s9pk::S9pk;
use crate::service::cli::ContainerCliContext;
use crate::service::ServiceActorSeed;
use crate::service::Service;
use crate::status::health_check::HealthCheckResult;
use crate::status::MainStatus;
use crate::util::clap::FromStrParser;
use crate::util::{new_guid, Invoke};
use crate::{echo, ARCH};
use crate::util::Invoke;
#[derive(Clone)]
pub(super) struct EffectContext(Weak<ServiceActorSeed>);
pub(super) struct EffectContext(Weak<Service>);
impl EffectContext {
pub fn new(seed: Weak<ServiceActorSeed>) -> Self {
Self(seed)
pub fn new(service: Weak<Service>) -> Self {
Self(service)
}
}
impl Context for EffectContext {}
impl EffectContext {
fn deref(&self) -> Result<Arc<ServiceActorSeed>, Error> {
fn deref(&self) -> Result<Arc<Service>, Error> {
if let Some(seed) = Weak::upgrade(&self.0) {
Ok(seed)
} else {
@@ -69,12 +64,6 @@ impl EffectContext {
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct RpcData {
id: i64,
method: String,
params: Value,
}
pub fn service_effect_handler<C: Context>() -> ParentHandler<C> {
ParentHandler::new()
.subcommand("gitInfo", from_fn(|_: C| crate::version::git_info()))
@@ -193,7 +182,6 @@ pub fn service_effect_handler<C: Context>() -> ParentHandler<C> {
.subcommand("removeAddress", from_fn_async(remove_address).no_cli())
.subcommand("exportAction", from_fn_async(export_action).no_cli())
.subcommand("removeAction", from_fn_async(remove_action).no_cli())
.subcommand("reverseProxy", from_fn_async(reverse_proxy).no_cli())
.subcommand("mount", from_fn_async(mount).no_cli())
// TODO Callbacks
@@ -233,8 +221,6 @@ struct ExportServiceInterfaceParams {
masked: bool,
address_info: AddressInfo,
r#type: ServiceInterfaceType,
host_kind: HostKind,
hostnames: Vec<ExportedHostnameInfo>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, TS)]
#[ts(export)]
@@ -242,9 +228,8 @@ struct ExportServiceInterfaceParams {
struct GetPrimaryUrlParams {
#[ts(type = "string | null")]
package_id: Option<PackageId>,
service_interface_id: String,
service_interface_id: ServiceInterfaceId,
callback: Callback,
host_id: HostId,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, TS)]
#[ts(export)]
@@ -276,37 +261,7 @@ struct RemoveActionParams {
#[ts(type = "string")]
id: ActionId,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, TS)]
#[ts(export)]
#[serde(rename_all = "camelCase")]
struct ReverseProxyBind {
ip: Option<String>,
port: u32,
ssl: bool,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, TS)]
#[ts(export)]
#[serde(rename_all = "camelCase")]
struct ReverseProxyDestination {
ip: Option<String>,
port: u32,
ssl: bool,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, TS)]
#[ts(export)]
#[serde(rename_all = "camelCase")]
struct ReverseProxyHttp {
#[ts(type = "null | {[key: string]: string}")]
headers: Option<OrdMap<String, String>>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, TS)]
#[ts(export)]
#[serde(rename_all = "camelCase")]
struct ReverseProxyParams {
bind: ReverseProxyBind,
dst: ReverseProxyDestination,
http: ReverseProxyHttp,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, TS)]
#[ts(export)]
#[serde(rename_all = "camelCase")]
@@ -328,6 +283,7 @@ struct MountParams {
async fn set_system_smtp(context: EffectContext, data: SetSystemSmtpParams) -> Result<(), Error> {
let context = context.deref()?;
context
.seed
.ctx
.db
.mutate(|db| {
@@ -342,6 +298,7 @@ async fn get_system_smtp(
) -> Result<String, Error> {
let context = context.deref()?;
let res = context
.seed
.ctx
.db
.peek()
@@ -361,24 +318,25 @@ async fn get_system_smtp(
}
async fn get_container_ip(context: EffectContext, _: Empty) -> Result<Ipv4Addr, Error> {
let context = context.deref()?;
let net_service = context.persistent_container.net_service.lock().await;
let net_service = context.seed.persistent_container.net_service.lock().await;
Ok(net_service.get_ip())
}
async fn get_service_port_forward(
context: EffectContext,
data: GetServicePortForwardParams,
) -> Result<u16, Error> {
) -> Result<LanInfo, Error> {
let internal_port = data.internal_port as u16;
let context = context.deref()?;
let net_service = context.persistent_container.net_service.lock().await;
let net_service = context.seed.persistent_container.net_service.lock().await;
net_service.get_ext_port(data.host_id, internal_port)
}
async fn clear_network_interfaces(context: EffectContext, _: Empty) -> Result<(), Error> {
let context = context.deref()?;
let package_id = context.id.clone();
let package_id = context.seed.id.clone();
context
.seed
.ctx
.db
.mutate(|db| {
@@ -404,13 +362,10 @@ async fn export_service_interface(
masked,
address_info,
r#type,
host_kind,
hostnames,
}: ExportServiceInterfaceParams,
) -> Result<(), Error> {
let context = context.deref()?;
let package_id = context.id.clone();
let host_id = address_info.host_id.clone();
let package_id = context.seed.id.clone();
let service_interface = ServiceInterface {
id: id.clone(),
@@ -422,17 +377,10 @@ async fn export_service_interface(
address_info,
interface_type: r#type,
};
let host_info = ExportedHostInfo {
id: host_id,
kind: host_kind,
hostnames,
};
let svc_interface_with_host_info = ServiceInterfaceWithHostInfo {
service_interface,
host_info,
};
let svc_interface_with_host_info = service_interface;
context
.seed
.ctx
.db
.mutate(|db| {
@@ -449,37 +397,29 @@ async fn export_service_interface(
}
async fn get_primary_url(
context: EffectContext,
data: GetPrimaryUrlParams,
) -> Result<HostAddress, Error> {
GetPrimaryUrlParams {
package_id,
service_interface_id,
callback,
}: GetPrimaryUrlParams,
) -> Result<Option<HostAddress>, Error> {
let context = context.deref()?;
let package_id = context.id.clone();
let package_id = package_id.unwrap_or_else(|| context.seed.id.clone());
let db_model = context.ctx.db.peek().await;
let pkg_data_model = db_model
.as_public()
.as_package_data()
.as_idx(&package_id)
.or_not_found(&package_id)?;
let host = pkg_data_model.de()?.hosts.get_host_primary(&data.host_id);
match host {
Some(host_address) => Ok(host_address),
None => Err(Error::new(
eyre!("Primary Url not found for {}", data.host_id),
crate::ErrorKind::NotFound,
)),
}
Ok(None) // TODO
}
async fn list_service_interfaces(
context: EffectContext,
data: ListServiceInterfacesParams,
) -> Result<BTreeMap<ServiceInterfaceId, ServiceInterfaceWithHostInfo>, Error> {
ListServiceInterfacesParams {
package_id,
callback,
}: ListServiceInterfacesParams,
) -> Result<BTreeMap<ServiceInterfaceId, ServiceInterface>, Error> {
let context = context.deref()?;
let package_id = context.id.clone();
let package_id = package_id.unwrap_or_else(|| context.seed.id.clone());
context
.seed
.ctx
.db
.peek()
@@ -493,9 +433,10 @@ async fn list_service_interfaces(
}
async fn remove_address(context: EffectContext, data: RemoveAddressParams) -> Result<(), Error> {
let context = context.deref()?;
let package_id = context.id.clone();
let package_id = context.seed.id.clone();
context
.seed
.ctx
.db
.mutate(|db| {
@@ -512,8 +453,9 @@ async fn remove_address(context: EffectContext, data: RemoveAddressParams) -> Re
}
async fn export_action(context: EffectContext, data: ExportActionParams) -> Result<(), Error> {
let context = context.deref()?;
let package_id = context.id.clone();
let package_id = context.seed.id.clone();
context
.seed
.ctx
.db
.mutate(|db| {
@@ -535,8 +477,9 @@ async fn export_action(context: EffectContext, data: ExportActionParams) -> Resu
}
async fn remove_action(context: EffectContext, data: RemoveActionParams) -> Result<(), Error> {
let context = context.deref()?;
let package_id = context.id.clone();
let package_id = context.seed.id.clone();
context
.seed
.ctx
.db
.mutate(|db| {
@@ -553,10 +496,8 @@ async fn remove_action(context: EffectContext, data: RemoveActionParams) -> Resu
.await?;
Ok(())
}
async fn reverse_proxy(context: EffectContext, data: ReverseProxyParams) -> Result<Value, Error> {
todo!()
}
async fn mount(context: EffectContext, data: MountParams) -> Result<Value, Error> {
// TODO
todo!()
}
@@ -564,49 +505,42 @@ async fn mount(context: EffectContext, data: MountParams) -> Result<Value, Error
#[ts(export)]
struct Callback(#[ts(type = "() => void")] i64);
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
enum GetHostInfoParamsKind {
Multi,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
struct GetHostInfoParams {
kind: Option<GetHostInfoParamsKind>,
service_interface_id: String,
host_id: HostId,
#[ts(type = "string | null")]
package_id: Option<PackageId>,
callback: Callback,
}
async fn get_host_info(
ctx: EffectContext,
GetHostInfoParams { .. }: GetHostInfoParams,
) -> Result<Value, Error> {
let ctx = ctx.deref()?;
Ok(json!({
"id": "fakeId1",
"kind": "multi",
"hostnames": [{
"kind": "ip",
"networkInterfaceId": "fakeNetworkInterfaceId1",
"public": true,
"hostname":{
"kind": "domain",
"domain": format!("{}", ctx.id),
"subdomain": (),
"port": (),
"sslPort": ()
}
}
context: EffectContext,
GetHostInfoParams {
callback,
package_id,
host_id,
}: GetHostInfoParams,
) -> Result<Host, Error> {
let context = context.deref()?;
let db = context.seed.ctx.db.peek().await;
let package_id = package_id.unwrap_or_else(|| context.seed.id.clone());
]
}))
db.as_public()
.as_package_data()
.as_idx(&package_id)
.or_not_found(&package_id)?
.as_hosts()
.as_idx(&host_id)
.or_not_found(&host_id)?
.de()
}
async fn clear_bindings(context: EffectContext, _: Empty) -> Result<Value, Error> {
todo!()
async fn clear_bindings(context: EffectContext, _: Empty) -> Result<(), Error> {
let context = context.deref()?;
let mut svc = context.seed.persistent_container.net_service.lock().await;
svc.clear_bindings().await?;
Ok(())
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, TS)]
@@ -619,17 +553,15 @@ struct BindParams {
#[serde(flatten)]
options: BindOptions,
}
async fn bind(
context: EffectContext,
BindParams {
async fn bind(context: EffectContext, bind_params: Value) -> Result<(), Error> {
let BindParams {
kind,
id,
internal_port,
options,
}: BindParams,
) -> Result<(), Error> {
let ctx = context.deref()?;
let mut svc = ctx.persistent_container.net_service.lock().await;
} = from_value(bind_params)?;
let context = context.deref()?;
let mut svc = context.seed.persistent_container.net_service.lock().await;
svc.bind(kind, id, internal_port, options).await
}
@@ -639,39 +571,32 @@ async fn bind(
struct GetServiceInterfaceParams {
#[ts(type = "string | null")]
package_id: Option<PackageId>,
service_interface_id: String,
service_interface_id: ServiceInterfaceId,
callback: Callback,
}
async fn get_service_interface(
_: EffectContext,
context: EffectContext,
GetServiceInterfaceParams {
callback,
package_id,
service_interface_id,
}: GetServiceInterfaceParams,
) -> Result<Value, Error> {
// TODO @Dr_Bonez
Ok(json!({
"id": service_interface_id,
"name": service_interface_id,
"description": "This is a fake",
"hasPrimary": true,
"disabled": false,
"masked": false,
"addressInfo": json!({
"username": Value::Null,
"hostId": "HostId?",
"options": json!({
"scheme": Value::Null,
"preferredExternalPort": 80,
"addSsl":Value::Null,
"secure": false,
"ssl": false
}),
"suffix": "http"
}),
"type": "api"
}))
) -> Result<ServiceInterface, Error> {
let context = context.deref()?;
let package_id = package_id.unwrap_or_else(|| context.seed.id.clone());
let db = context.seed.ctx.db.peek().await;
let interface = db
.as_public()
.as_package_data()
.as_idx(&package_id)
.or_not_found(&package_id)?
.as_service_interfaces()
.as_idx(&service_interface_id)
.or_not_found(&service_interface_id)?
.de()?;
Ok(interface)
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Parser, TS)]
@@ -764,6 +689,7 @@ async fn get_ssl_certificate(
host_id,
}: GetSslCertificateParams,
) -> Result<Value, Error> {
// TODO
let fake = include_str!("./fake.cert.pem");
Ok(json!([fake, fake, fake]))
}
@@ -785,6 +711,7 @@ async fn get_ssl_key(
algorithm,
}: GetSslKeyParams,
) -> Result<Value, Error> {
// TODO
let fake = include_str!("./fake.cert.key");
Ok(json!(fake))
}
@@ -803,8 +730,8 @@ async fn get_store(
GetStoreParams { package_id, path }: GetStoreParams,
) -> Result<Value, Error> {
let context = context.deref()?;
let peeked = context.ctx.db.peek().await;
let package_id = package_id.unwrap_or(context.id.clone());
let peeked = context.seed.ctx.db.peek().await;
let package_id = package_id.unwrap_or(context.seed.id.clone());
let value = peeked
.as_private()
.as_package_stores()
@@ -832,8 +759,9 @@ async fn set_store(
SetStoreParams { value, path }: SetStoreParams,
) -> Result<(), Error> {
let context = context.deref()?;
let package_id = context.id.clone();
let package_id = context.seed.id.clone();
context
.seed
.ctx
.db
.mutate(|db| {
@@ -886,7 +814,7 @@ struct ParamsMaybePackageId {
async fn exists(context: EffectContext, params: ParamsPackageId) -> Result<Value, Error> {
let context = context.deref()?;
let peeked = context.ctx.db.peek().await;
let peeked = context.seed.ctx.db.peek().await;
let package = peeked
.as_public()
.as_package_data()
@@ -899,6 +827,8 @@ async fn exists(context: EffectContext, params: ParamsPackageId) -> Result<Value
#[serde(rename_all = "camelCase")]
#[ts(export)]
struct ExecuteAction {
#[serde(default)]
procedure_id: Guid,
#[ts(type = "string | null")]
service_id: Option<PackageId>,
#[ts(type = "string")]
@@ -909,30 +839,26 @@ struct ExecuteAction {
async fn execute_action(
context: EffectContext,
ExecuteAction {
procedure_id,
service_id,
action_id,
input,
service_id,
}: ExecuteAction,
) -> Result<Value, Error> {
let context = context.deref()?;
let package_id = service_id.clone().unwrap_or_else(|| context.id.clone());
let service = context.ctx.services.get(&package_id).await;
let service = service.as_ref().ok_or_else(|| {
Error::new(
eyre!("Could not find package {package_id}"),
ErrorKind::Unknown,
)
})?;
let package_id = service_id
.clone()
.unwrap_or_else(|| context.seed.id.clone());
Ok(json!(service.action(action_id, input).await?))
Ok(json!(context.action(procedure_id, action_id, input).await?))
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct FromService {}
async fn get_configured(context: EffectContext, _: Empty) -> Result<Value, Error> {
let context = context.deref()?;
let peeked = context.ctx.db.peek().await;
let package_id = &context.id;
let peeked = context.seed.ctx.db.peek().await;
let package_id = &context.seed.id;
let package = peeked
.as_public()
.as_package_data()
@@ -946,8 +872,8 @@ async fn get_configured(context: EffectContext, _: Empty) -> Result<Value, Error
async fn stopped(context: EffectContext, params: ParamsMaybePackageId) -> Result<Value, Error> {
let context = context.deref()?;
let peeked = context.ctx.db.peek().await;
let package_id = params.package_id.unwrap_or_else(|| context.id.clone());
let peeked = context.seed.ctx.db.peek().await;
let package_id = params.package_id.unwrap_or_else(|| context.seed.id.clone());
let package = peeked
.as_public()
.as_package_data()
@@ -959,9 +885,8 @@ async fn stopped(context: EffectContext, params: ParamsMaybePackageId) -> Result
Ok(json!(matches!(package, MainStatus::Stopped)))
}
async fn running(context: EffectContext, params: ParamsPackageId) -> Result<Value, Error> {
dbg!("Starting the running {params:?}");
let context = context.deref()?;
let peeked = context.ctx.db.peek().await;
let peeked = context.seed.ctx.db.peek().await;
let package_id = params.package_id;
let package = peeked
.as_public()
@@ -974,30 +899,66 @@ async fn running(context: EffectContext, params: ParamsPackageId) -> Result<Valu
Ok(json!(matches!(package, MainStatus::Running { .. })))
}
async fn restart(context: EffectContext, _: Empty) -> Result<Value, Error> {
let context = context.deref()?;
let service = context.ctx.services.get(&context.id).await;
let service = service.as_ref().ok_or_else(|| {
Error::new(
eyre!("Could not find package {}", context.id),
ErrorKind::Unknown,
)
})?;
service.restart().await?;
Ok(json!(()))
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
struct ProcedureId {
#[serde(default)]
procedure_id: Guid,
}
async fn shutdown(context: EffectContext, _: Empty) -> Result<Value, Error> {
let context = context.deref()?;
let service = context.ctx.services.get(&context.id).await;
let service = service.as_ref().ok_or_else(|| {
Error::new(
eyre!("Could not find package {}", context.id),
ErrorKind::Unknown,
impl FromArgMatches for ProcedureId {
fn from_arg_matches(matches: &clap::ArgMatches) -> Result<Self, clap::Error> {
Ok(Self {
procedure_id: matches.get_one("procedure-id").cloned().unwrap_or_default(),
})
}
fn from_arg_matches_mut(matches: &mut clap::ArgMatches) -> Result<Self, clap::Error> {
Ok(Self {
procedure_id: matches.get_one("procedure-id").cloned().unwrap_or_default(),
})
}
fn update_from_arg_matches(&mut self, matches: &clap::ArgMatches) -> Result<(), clap::Error> {
self.procedure_id = matches.get_one("procedure-id").cloned().unwrap_or_default();
Ok(())
}
fn update_from_arg_matches_mut(
&mut self,
matches: &mut clap::ArgMatches,
) -> Result<(), clap::Error> {
self.procedure_id = matches.get_one("procedure-id").cloned().unwrap_or_default();
Ok(())
}
}
impl CommandFactory for ProcedureId {
fn command() -> clap::Command {
Self::command_for_update().arg(
clap::Arg::new("procedure-id")
.action(clap::ArgAction::Set)
.value_parser(clap::value_parser!(Guid)),
)
})?;
service.stop().await?;
Ok(json!(()))
}
fn command_for_update() -> clap::Command {
Self::command()
}
}
async fn restart(
context: EffectContext,
ProcedureId { procedure_id }: ProcedureId,
) -> Result<(), Error> {
let context = context.deref()?;
context.restart(procedure_id).await?;
Ok(())
}
async fn shutdown(
context: EffectContext,
ProcedureId { procedure_id }: ProcedureId,
) -> Result<(), Error> {
let context = context.deref()?;
context.stop(procedure_id).await?;
Ok(())
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Parser, TS)]
@@ -1009,8 +970,9 @@ struct SetConfigured {
}
async fn set_configured(context: EffectContext, params: SetConfigured) -> Result<Value, Error> {
let context = context.deref()?;
let package_id = &context.id;
let package_id = &context.seed.id;
context
.seed
.ctx
.db
.mutate(|db| {
@@ -1032,7 +994,6 @@ async fn set_configured(context: EffectContext, params: SetConfigured) -> Result
enum SetMainStatusStatus {
Running,
Stopped,
Starting,
}
impl FromStr for SetMainStatusStatus {
type Err = color_eyre::eyre::Report;
@@ -1040,7 +1001,6 @@ impl FromStr for SetMainStatusStatus {
match s {
"running" => Ok(Self::Running),
"stopped" => Ok(Self::Stopped),
"starting" => Ok(Self::Starting),
_ => Err(eyre!("unknown status {s}")),
}
}
@@ -1060,12 +1020,10 @@ struct SetMainStatus {
status: SetMainStatusStatus,
}
async fn set_main_status(context: EffectContext, params: SetMainStatus) -> Result<Value, Error> {
dbg!(format!("Status for main will be is {params:?}"));
let context = context.deref()?;
match params.status {
SetMainStatusStatus::Running => context.started(),
SetMainStatusStatus::Stopped => context.stopped(),
SetMainStatusStatus::Starting => context.stopped(),
SetMainStatusStatus::Running => context.seed.started(),
SetMainStatusStatus::Stopped => context.seed.stopped(),
}
Ok(Value::Null)
}
@@ -1085,8 +1043,9 @@ async fn set_health(
) -> Result<Value, Error> {
let context = context.deref()?;
let package_id = &context.id;
let package_id = &context.seed.id;
context
.seed
.ctx
.db
.mutate(move |db| {
@@ -1115,17 +1074,17 @@ async fn set_health(
#[command(rename_all = "camelCase")]
#[ts(export)]
pub struct DestroyOverlayedImageParams {
#[ts(type = "string")]
guid: InternedString,
guid: Guid,
}
#[instrument(skip_all)]
pub async fn destroy_overlayed_image(
ctx: EffectContext,
context: EffectContext,
DestroyOverlayedImageParams { guid }: DestroyOverlayedImageParams,
) -> Result<(), Error> {
let ctx = ctx.deref()?;
if ctx
let context = context.deref()?;
if context
.seed
.persistent_container
.overlays
.lock()
@@ -1142,30 +1101,25 @@ pub async fn destroy_overlayed_image(
#[command(rename_all = "camelCase")]
#[ts(export)]
pub struct CreateOverlayedImageParams {
#[ts(type = "string")]
image_id: ImageId,
}
#[instrument(skip_all)]
pub async fn create_overlayed_image(
ctx: EffectContext,
context: EffectContext,
CreateOverlayedImageParams { image_id }: CreateOverlayedImageParams,
) -> Result<(PathBuf, InternedString), Error> {
let ctx = ctx.deref()?;
let path = Path::new("images")
.join(*ARCH)
.join(&image_id)
.with_extension("squashfs");
if let Some(image) = ctx
) -> Result<(PathBuf, Guid), Error> {
let context = context.deref()?;
if let Some(image) = context
.seed
.persistent_container
.s9pk
.as_archive()
.contents()
.get_path(&path)
.and_then(|e| e.as_file())
.images
.get(&image_id)
.cloned()
{
let guid = new_guid();
let rootfs_dir = ctx
let guid = Guid::new();
let rootfs_dir = context
.seed
.persistent_container
.lxc_container
.get()
@@ -1176,7 +1130,9 @@ pub async fn create_overlayed_image(
)
})?
.rootfs_dir();
let mountpoint = rootfs_dir.join("media/startos/overlays").join(&*guid);
let mountpoint = rootfs_dir
.join("media/startos/overlays")
.join(guid.as_ref());
tokio::fs::create_dir_all(&mountpoint).await?;
let container_mountpoint = Path::new("/").join(
mountpoint
@@ -1184,18 +1140,16 @@ pub async fn create_overlayed_image(
.with_kind(ErrorKind::Incoherent)?,
);
tracing::info!("Mounting overlay {guid} for {image_id}");
let guard = OverlayGuard::mount(
&IdMapped::new(LoopDev::from(&**image), 0, 100000, 65536),
&mountpoint,
)
.await?;
let guard = OverlayGuard::mount(image, &mountpoint).await?;
Command::new("chown")
.arg("100000:100000")
.arg(&mountpoint)
.invoke(ErrorKind::Filesystem)
.await?;
tracing::info!("Mounted overlay {guid} for {image_id}");
ctx.persistent_container
context
.seed
.persistent_container
.overlays
.lock()
.await
@@ -1298,17 +1252,21 @@ impl ValueParserFactory for DependencyRequirement {
#[command(rename_all = "camelCase")]
#[ts(export)]
struct SetDependenciesParams {
#[serde(default)]
procedure_id: Guid,
dependencies: Vec<DependencyRequirement>,
}
async fn set_dependencies(
ctx: EffectContext,
SetDependenciesParams { dependencies }: SetDependenciesParams,
context: EffectContext,
SetDependenciesParams {
procedure_id,
dependencies,
}: SetDependenciesParams,
) -> Result<(), Error> {
let ctx = ctx.deref()?;
let id = &ctx.id;
let service_guard = ctx.ctx.services.get(id).await;
let service = service_guard.as_ref().or_not_found(id)?;
let context = context.deref()?;
let id = &context.seed.id;
let mut deps = BTreeMap::new();
for dependency in dependencies {
let (dep_id, kind, registry_url, version_spec) = match dependency {
@@ -1338,14 +1296,13 @@ async fn set_dependencies(
let remote_s9pk = S9pk::deserialize(
&Arc::new(
HttpSource::new(
ctx.ctx.client.clone(),
context.seed.ctx.client.clone(),
registry_url
.join(&format!("package/v2/{}.s9pk?spec={}", dep_id, version_spec))?,
)
.await?,
),
None, // TODO
true,
)
.await?;
@@ -1365,14 +1322,19 @@ async fn set_dependencies(
)
}
};
let config_satisfied = if let Some(dep_service) = &*ctx.ctx.services.get(&dep_id).await {
service
.dependency_config(dep_id.clone(), dep_service.get_config().await?.config)
.await?
.is_none()
} else {
true
};
let config_satisfied =
if let Some(dep_service) = &*context.seed.ctx.services.get(&dep_id).await {
context
.dependency_config(
procedure_id.clone(),
dep_id.clone(),
dep_service.get_config(procedure_id.clone()).await?.config,
)
.await?
.is_none()
} else {
true
};
deps.insert(
dep_id,
CurrentDependencyInfo {
@@ -1385,7 +1347,9 @@ async fn set_dependencies(
},
);
}
ctx.ctx
context
.seed
.ctx
.db
.mutate(|db| {
db.as_public_mut()
@@ -1398,10 +1362,10 @@ async fn set_dependencies(
.await
}
async fn get_dependencies(ctx: EffectContext) -> Result<Vec<DependencyRequirement>, Error> {
let ctx = ctx.deref()?;
let id = &ctx.id;
let db = ctx.ctx.db.peek().await;
async fn get_dependencies(context: EffectContext) -> Result<Vec<DependencyRequirement>, Error> {
let context = context.deref()?;
let id = &context.seed.id;
let db = context.seed.ctx.db.peek().await;
let data = db
.as_public()
.as_package_data()
@@ -1458,16 +1422,16 @@ struct CheckDependenciesResult {
}
async fn check_dependencies(
ctx: EffectContext,
context: EffectContext,
CheckDependenciesParam { package_ids }: CheckDependenciesParam,
) -> Result<Vec<CheckDependenciesResult>, Error> {
let ctx = ctx.deref()?;
let db = ctx.ctx.db.peek().await;
let context = context.deref()?;
let db = context.seed.ctx.db.peek().await;
let current_dependencies = db
.as_public()
.as_package_data()
.as_idx(&ctx.id)
.or_not_found(&ctx.id)?
.as_idx(&context.seed.id)
.or_not_found(&context.seed.id)?
.as_current_dependencies()
.de()?;
let package_ids: Vec<_> = package_ids

View File

@@ -18,14 +18,11 @@ use crate::disk::mount::guard::GenericMountGuard;
use crate::install::PKG_ARCHIVE_DIR;
use crate::notifications::{notify, NotificationLevel};
use crate::prelude::*;
use crate::progress::{
FullProgressTracker, FullProgressTrackerHandle, PhaseProgressTrackerHandle,
ProgressTrackerWriter,
};
use crate::progress::{FullProgressTracker, PhaseProgressTrackerHandle, ProgressTrackerWriter};
use crate::s9pk::manifest::PackageId;
use crate::s9pk::merkle_archive::source::FileSource;
use crate::s9pk::S9pk;
use crate::service::{LoadDisposition, Service};
use crate::service::{LoadDisposition, Service, ServiceRef};
use crate::status::{MainStatus, Status};
use crate::util::serde::Pem;
@@ -34,39 +31,47 @@ pub type InstallFuture = BoxFuture<'static, Result<(), Error>>;
pub struct InstallProgressHandles {
pub finalization_progress: PhaseProgressTrackerHandle,
pub progress_handle: FullProgressTrackerHandle,
pub progress: FullProgressTracker,
}
/// This is the structure to contain all the services
#[derive(Default)]
pub struct ServiceMap(Mutex<OrdMap<PackageId, Arc<RwLock<Option<Service>>>>>);
pub struct ServiceMap(Mutex<OrdMap<PackageId, Arc<RwLock<Option<ServiceRef>>>>>);
impl ServiceMap {
async fn entry(&self, id: &PackageId) -> Arc<RwLock<Option<Service>>> {
async fn entry(&self, id: &PackageId) -> Arc<RwLock<Option<ServiceRef>>> {
let mut lock = self.0.lock().await;
dbg!(lock.keys().collect::<Vec<_>>());
lock.entry(id.clone())
.or_insert_with(|| Arc::new(RwLock::new(None)))
.clone()
}
#[instrument(skip_all)]
pub async fn get(&self, id: &PackageId) -> OwnedRwLockReadGuard<Option<Service>> {
pub async fn get(&self, id: &PackageId) -> OwnedRwLockReadGuard<Option<ServiceRef>> {
self.entry(id).await.read_owned().await
}
#[instrument(skip_all)]
pub async fn get_mut(&self, id: &PackageId) -> OwnedRwLockWriteGuard<Option<Service>> {
pub async fn get_mut(&self, id: &PackageId) -> OwnedRwLockWriteGuard<Option<ServiceRef>> {
self.entry(id).await.write_owned().await
}
#[instrument(skip_all)]
pub async fn init(&self, ctx: &RpcContext) -> Result<(), Error> {
for id in ctx.db.peek().await.as_public().as_package_data().keys()? {
pub async fn init(
&self,
ctx: &RpcContext,
mut progress: PhaseProgressTrackerHandle,
) -> Result<(), Error> {
progress.start();
let ids = ctx.db.peek().await.as_public().as_package_data().keys()?;
progress.set_total(ids.len() as u64);
for id in ids {
if let Err(e) = self.load(ctx, &id, LoadDisposition::Retry).await {
tracing::error!("Error loading installed package as service: {e}");
tracing::debug!("{e:?}");
}
progress += 1;
}
progress.complete();
Ok(())
}
@@ -83,7 +88,7 @@ impl ServiceMap {
shutdown_err = service.shutdown().await;
}
// TODO: retry on error?
*service = Service::load(ctx, id, disposition).await?;
*service = Service::load(ctx, id, disposition).await?.map(From::from);
shutdown_err?;
Ok(())
}
@@ -95,6 +100,7 @@ impl ServiceMap {
mut s9pk: S9pk<S>,
recovery_source: Option<impl GenericMountGuard>,
) -> Result<DownloadInstallFuture, Error> {
s9pk.validate_and_filter(ctx.s9pk_arch)?;
let manifest = s9pk.as_manifest().clone();
let id = manifest.id.clone();
let icon = s9pk.icon_data_url().await?;
@@ -112,23 +118,22 @@ impl ServiceMap {
};
let size = s9pk.size();
let mut progress = FullProgressTracker::new();
let progress = FullProgressTracker::new();
let download_progress_contribution = size.unwrap_or(60);
let progress_handle = progress.handle();
let mut download_progress = progress_handle.add_phase(
let mut download_progress = progress.add_phase(
InternedString::intern("Download"),
Some(download_progress_contribution),
);
if let Some(size) = size {
download_progress.set_total(size);
}
let mut finalization_progress = progress_handle.add_phase(
let mut finalization_progress = progress.add_phase(
InternedString::intern(op_name),
Some(download_progress_contribution / 2),
);
let restoring = recovery_source.is_some();
let mut reload_guard = ServiceReloadGuard::new(ctx.clone(), id.clone(), op_name);
let mut reload_guard = ServiceRefReloadGuard::new(ctx.clone(), id.clone(), op_name);
reload_guard
.handle(ctx.db.mutate({
@@ -194,7 +199,7 @@ impl ServiceMap {
let deref_id = id.clone();
let sync_progress_task =
NonDetachingJoinHandle::from(tokio::spawn(progress.sync_to_db(
NonDetachingJoinHandle::from(tokio::spawn(progress.clone().sync_to_db(
ctx.db.clone(),
move |v| {
v.as_public_mut()
@@ -231,7 +236,7 @@ impl ServiceMap {
Ok(reload_guard
.handle_last(async move {
finalization_progress.start();
let s9pk = S9pk::open(&installed_path, Some(&id), true).await?;
let s9pk = S9pk::open(&installed_path, Some(&id)).await?;
let prev = if let Some(service) = service.take() {
ensure_code!(
recovery_source.is_none(),
@@ -248,7 +253,7 @@ impl ServiceMap {
service
.uninstall(Some(s9pk.as_manifest().version.clone()))
.await?;
progress_handle.complete();
progress.complete();
Some(version)
} else {
None
@@ -261,10 +266,11 @@ impl ServiceMap {
recovery_source,
Some(InstallProgressHandles {
finalization_progress,
progress_handle,
progress,
}),
)
.await?,
.await?
.into(),
);
} else {
*service = Some(
@@ -274,10 +280,11 @@ impl ServiceMap {
prev,
Some(InstallProgressHandles {
finalization_progress,
progress_handle,
progress,
}),
)
.await?,
.await?
.into(),
);
}
sync_progress_task.await.map_err(|_| {
@@ -295,7 +302,7 @@ impl ServiceMap {
pub async fn uninstall(&self, ctx: &RpcContext, id: &PackageId) -> Result<(), Error> {
let mut guard = self.get_mut(id).await;
if let Some(service) = guard.take() {
ServiceReloadGuard::new(ctx.clone(), id.clone(), "Uninstall")
ServiceRefReloadGuard::new(ctx.clone(), id.clone(), "Uninstall")
.handle_last(async move {
let res = service.uninstall(None).await;
drop(guard);
@@ -326,17 +333,17 @@ impl ServiceMap {
}
}
pub struct ServiceReloadGuard(Option<ServiceReloadInfo>);
impl Drop for ServiceReloadGuard {
pub struct ServiceRefReloadGuard(Option<ServiceRefReloadInfo>);
impl Drop for ServiceRefReloadGuard {
fn drop(&mut self) {
if let Some(info) = self.0.take() {
tokio::spawn(info.reload(None));
}
}
}
impl ServiceReloadGuard {
impl ServiceRefReloadGuard {
pub fn new(ctx: RpcContext, id: PackageId, operation: &'static str) -> Self {
Self(Some(ServiceReloadInfo { ctx, id, operation }))
Self(Some(ServiceRefReloadInfo { ctx, id, operation }))
}
pub async fn handle<T>(
@@ -365,12 +372,12 @@ impl ServiceReloadGuard {
}
}
struct ServiceReloadInfo {
struct ServiceRefReloadInfo {
ctx: RpcContext,
id: PackageId,
operation: &'static str,
}
impl ServiceReloadInfo {
impl ServiceRefReloadInfo {
async fn reload(self, error: Option<Error>) -> Result<(), Error> {
self.ctx
.services

View File

@@ -6,6 +6,7 @@ use models::ProcedureName;
use super::TempDesiredRestore;
use crate::disk::mount::filesystem::ReadWrite;
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::service::config::GetConfig;
use crate::service::dependencies::DependencyConfig;
use crate::service::transition::{TransitionKind, TransitionState};
@@ -24,7 +25,12 @@ impl Handler<Backup> for ServiceActor {
.except::<GetConfig>()
.except::<DependencyConfig>()
}
async fn handle(&mut self, backup: Backup, jobs: &BackgroundJobQueue) -> Self::Response {
async fn handle(
&mut self,
id: Guid,
backup: Backup,
jobs: &BackgroundJobQueue,
) -> Self::Response {
// So Need a handle to just a single field in the state
let temp: TempDesiredRestore = TempDesiredRestore::new(&self.0.persistent_container.state);
let mut current = self.0.persistent_container.state.subscribe();
@@ -45,7 +51,7 @@ impl Handler<Backup> for ServiceActor {
.mount_backup(path, ReadWrite)
.await?;
seed.persistent_container
.execute(ProcedureName::CreateBackup, Value::Null, None)
.execute(id, ProcedureName::CreateBackup, Value::Null, None)
.await?;
backup_guard.unmount(true).await?;

View File

@@ -2,6 +2,7 @@ use futures::FutureExt;
use super::TempDesiredRestore;
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::service::config::GetConfig;
use crate::service::dependencies::DependencyConfig;
use crate::service::transition::{TransitionKind, TransitionState};
@@ -18,7 +19,8 @@ impl Handler<Restart> for ServiceActor {
.except::<GetConfig>()
.except::<DependencyConfig>()
}
async fn handle(&mut self, _: Restart, jobs: &BackgroundJobQueue) -> Self::Response {
async fn handle(&mut self, _: Guid, _: Restart, jobs: &BackgroundJobQueue) -> Self::Response {
dbg!("here");
// So Need a handle to just a single field in the state
let temp = TempDesiredRestore::new(&self.0.persistent_container.state);
let mut current = self.0.persistent_container.state.subscribe();
@@ -74,7 +76,8 @@ impl Handler<Restart> for ServiceActor {
}
impl Service {
#[instrument(skip_all)]
pub async fn restart(&self) -> Result<(), Error> {
self.actor.send(Restart).await
pub async fn restart(&self, id: Guid) -> Result<(), Error> {
dbg!("here");
self.actor.send(id, Restart).await
}
}

View File

@@ -5,6 +5,7 @@ use models::ProcedureName;
use crate::disk::mount::filesystem::ReadOnly;
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::service::transition::{TransitionKind, TransitionState};
use crate::service::ServiceActor;
use crate::util::actor::background::BackgroundJobQueue;
@@ -19,7 +20,12 @@ impl Handler<Restore> for ServiceActor {
fn conflicts_with(_: &Restore) -> ConflictBuilder<Self> {
ConflictBuilder::everything()
}
async fn handle(&mut self, restore: Restore, jobs: &BackgroundJobQueue) -> Self::Response {
async fn handle(
&mut self,
id: Guid,
restore: Restore,
jobs: &BackgroundJobQueue,
) -> Self::Response {
// So Need a handle to just a single field in the state
let path = restore.path.clone();
let seed = self.0.clone();
@@ -32,7 +38,7 @@ impl Handler<Restore> for ServiceActor {
.mount_backup(path, ReadOnly)
.await?;
seed.persistent_container
.execute(ProcedureName::RestoreBackup, Value::Null, None)
.execute(id, ProcedureName::RestoreBackup, Value::Null, None)
.await?;
backup_guard.unmount(true).await?;

View File

@@ -4,7 +4,6 @@ use std::time::Duration;
use color_eyre::eyre::eyre;
use josekit::jwk::Jwk;
use openssl::x509::X509;
use patch_db::json_ptr::ROOT;
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{from_fn_async, Context, HandlerExt, ParentHandler};
@@ -12,15 +11,15 @@ use serde::{Deserialize, Serialize};
use tokio::fs::File;
use tokio::io::AsyncWriteExt;
use tokio::try_join;
use torut::onion::OnionAddressV3;
use tracing::instrument;
use ts_rs::TS;
use crate::account::AccountInfo;
use crate::backup::restore::recover_full_embassy;
use crate::backup::target::BackupTargetFS;
use crate::context::rpc::InitRpcContextPhases;
use crate::context::setup::SetupResult;
use crate::context::SetupContext;
use crate::context::{RpcContext, SetupContext};
use crate::db::model::Database;
use crate::disk::fsck::RepairStrategy;
use crate::disk::main::DEFAULT_PASSWORD;
@@ -29,10 +28,12 @@ use crate::disk::mount::filesystem::ReadWrite;
use crate::disk::mount::guard::{GenericMountGuard, TmpMountGuard};
use crate::disk::util::{pvscan, recovery_info, DiskInfo, EmbassyOsRecoveryInfo};
use crate::disk::REPAIR_DISK_PATH;
use crate::hostname::Hostname;
use crate::init::{init, InitResult};
use crate::init::{init, InitPhases, InitResult};
use crate::net::net_controller::PreInitNetController;
use crate::net::ssl::root_ca_start_time;
use crate::prelude::*;
use crate::progress::{FullProgress, PhaseProgressTrackerHandle};
use crate::rpc_continuations::Guid;
use crate::util::crypto::EncryptedWire;
use crate::util::io::{dir_copy, dir_size, Counter};
use crate::{Error, ErrorKind, ResultExt};
@@ -75,10 +76,12 @@ pub async fn list_disks(ctx: SetupContext) -> Result<Vec<DiskInfo>, Error> {
async fn setup_init(
ctx: &SetupContext,
password: Option<String>,
) -> Result<(Hostname, OnionAddressV3, X509), Error> {
let InitResult { db } = init(&ctx.config).await?;
init_phases: InitPhases,
) -> Result<(AccountInfo, PreInitNetController), Error> {
let InitResult { net_ctrl } = init(&ctx.config, init_phases).await?;
let account = db
let account = net_ctrl
.db
.mutate(|m| {
let mut account = AccountInfo::load(m)?;
if let Some(password) = password {
@@ -93,15 +96,12 @@ async fn setup_init(
})
.await?;
Ok((
account.hostname,
account.tor_key.public().get_onion_address(),
account.root_ca_cert,
))
Ok((account, net_ctrl))
}
#[derive(Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct AttachParams {
#[serde(rename = "startOsPassword")]
password: Option<EncryptedWire>,
@@ -110,25 +110,20 @@ pub struct AttachParams {
pub async fn attach(
ctx: SetupContext,
AttachParams { password, guid }: AttachParams,
) -> Result<(), Error> {
let mut status = ctx.setup_status.write().await;
if status.is_some() {
return Err(Error::new(
eyre!("Setup already in progress"),
ErrorKind::InvalidRequest,
));
}
*status = Some(Ok(SetupStatus {
bytes_transferred: 0,
total_bytes: None,
complete: false,
}));
drop(status);
tokio::task::spawn(async move {
if let Err(e) = async {
AttachParams {
password,
guid: disk_guid,
}: AttachParams,
) -> Result<SetupProgress, Error> {
let setup_ctx = ctx.clone();
ctx.run_setup(|| async move {
let progress = &setup_ctx.progress;
let mut disk_phase = progress.add_phase("Opening data drive".into(), Some(10));
let init_phases = InitPhases::new(&progress);
let rpc_ctx_phases = InitRpcContextPhases::new(&progress);
let password: Option<String> = match password {
Some(a) => match a.decrypt(&*ctx) {
Some(a) => match a.decrypt(&setup_ctx) {
a @ Some(_) => a,
None => {
return Err(Error::new(
@@ -139,15 +134,17 @@ pub async fn attach(
},
None => None,
};
disk_phase.start();
let requires_reboot = crate::disk::main::import(
&*guid,
&ctx.datadir,
&*disk_guid,
&setup_ctx.datadir,
if tokio::fs::metadata(REPAIR_DISK_PATH).await.is_ok() {
RepairStrategy::Aggressive
} else {
RepairStrategy::Preen
},
if guid.ends_with("_UNENC") { None } else { Some(DEFAULT_PASSWORD) },
if disk_guid.ends_with("_UNENC") { None } else { Some(DEFAULT_PASSWORD) },
)
.await?;
if tokio::fs::metadata(REPAIR_DISK_PATH).await.is_ok() {
@@ -156,7 +153,7 @@ pub async fn attach(
.with_ctx(|_| (ErrorKind::Filesystem, REPAIR_DISK_PATH))?;
}
if requires_reboot.0 {
crate::disk::main::export(&*guid, &ctx.datadir).await?;
crate::disk::main::export(&*disk_guid, &setup_ctx.datadir).await?;
return Err(Error::new(
eyre!(
"Errors were corrected with your disk, but the server must be restarted in order to proceed"
@@ -164,37 +161,48 @@ pub async fn attach(
ErrorKind::DiskManagement,
));
}
let (hostname, tor_addr, root_ca) = setup_init(&ctx, password).await?;
*ctx.setup_result.write().await = Some((guid, SetupResult {
tor_address: format!("https://{}", tor_addr),
lan_address: hostname.lan_address(),
root_ca: String::from_utf8(root_ca.to_pem()?)?,
}));
*ctx.setup_status.write().await = Some(Ok(SetupStatus {
bytes_transferred: 0,
total_bytes: None,
complete: true,
}));
Ok(())
}.await {
tracing::error!("Error Setting Up Embassy: {}", e);
tracing::debug!("{:?}", e);
*ctx.setup_status.write().await = Some(Err(e.into()));
}
});
Ok(())
disk_phase.complete();
let (account, net_ctrl) = setup_init(&setup_ctx, password, init_phases).await?;
let rpc_ctx = RpcContext::init(&setup_ctx.config, disk_guid, Some(net_ctrl), rpc_ctx_phases).await?;
Ok(((&account).try_into()?, rpc_ctx))
})?;
Ok(ctx.progress().await)
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct SetupStatus {
pub bytes_transferred: u64,
pub total_bytes: Option<u64>,
pub complete: bool,
#[ts(export)]
#[serde(tag = "status")]
pub enum SetupStatusRes {
Complete(SetupResult),
Running(SetupProgress),
}
pub async fn status(ctx: SetupContext) -> Result<Option<SetupStatus>, RpcError> {
ctx.setup_status.read().await.clone().transpose()
#[derive(Debug, Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct SetupProgress {
pub progress: FullProgress,
pub guid: Guid,
}
pub async fn status(ctx: SetupContext) -> Result<Option<SetupStatusRes>, Error> {
if let Some(res) = ctx.result.get() {
match res {
Ok((res, _)) => Ok(Some(SetupStatusRes::Complete(res.clone()))),
Err(e) => Err(e.clone_output()),
}
} else {
if ctx.task.initialized() {
Ok(Some(SetupStatusRes::Running(ctx.progress().await)))
} else {
Ok(None)
}
}
}
/// We want to be able to get a secret, a shared private key with the frontend
@@ -202,7 +210,7 @@ pub async fn status(ctx: SetupContext) -> Result<Option<SetupStatus>, RpcError>
/// without knowing the password over clearnet. We use the public key shared across the network
/// since it is fine to share the public, and encrypt against the public.
pub async fn get_pubkey(ctx: SetupContext) -> Result<Jwk, RpcError> {
let secret = ctx.as_ref().clone();
let secret = AsRef::<Jwk>::as_ref(&ctx).clone();
let pub_key = secret.to_public_key()?;
Ok(pub_key)
}
@@ -213,6 +221,7 @@ pub fn cifs<C: Context>() -> ParentHandler<C> {
#[derive(Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export)]
pub struct VerifyCifsParams {
hostname: String,
path: PathBuf,
@@ -230,7 +239,7 @@ pub async fn verify_cifs(
password,
}: VerifyCifsParams,
) -> Result<EmbassyOsRecoveryInfo, Error> {
let password: Option<String> = password.map(|x| x.decrypt(&*ctx)).flatten();
let password: Option<String> = password.map(|x| x.decrypt(&ctx)).flatten();
let guard = TmpMountGuard::mount(
&Cifs {
hostname,
@@ -256,7 +265,8 @@ pub enum RecoverySource {
#[derive(Deserialize, Serialize, TS)]
#[serde(rename_all = "camelCase")]
pub struct ExecuteParams {
#[ts(export)]
pub struct SetupExecuteParams {
start_os_logicalname: PathBuf,
start_os_password: EncryptedWire,
recovery_source: Option<RecoverySource>,
@@ -266,104 +276,65 @@ pub struct ExecuteParams {
// #[command(rpc_only)]
pub async fn execute(
ctx: SetupContext,
ExecuteParams {
SetupExecuteParams {
start_os_logicalname,
start_os_password,
recovery_source,
recovery_password,
}: ExecuteParams,
) -> Result<(), Error> {
let start_os_password = match start_os_password.decrypt(&*ctx) {
}: SetupExecuteParams,
) -> Result<SetupProgress, Error> {
let start_os_password = match start_os_password.decrypt(&ctx) {
Some(a) => a,
None => {
return Err(Error::new(
color_eyre::eyre::eyre!("Couldn't decode embassy-password"),
color_eyre::eyre::eyre!("Couldn't decode startOsPassword"),
crate::ErrorKind::Unknown,
))
}
};
let recovery_password: Option<String> = match recovery_password {
Some(a) => match a.decrypt(&*ctx) {
Some(a) => match a.decrypt(&ctx) {
Some(a) => Some(a),
None => {
return Err(Error::new(
color_eyre::eyre::eyre!("Couldn't decode recovery-password"),
color_eyre::eyre::eyre!("Couldn't decode recoveryPassword"),
crate::ErrorKind::Unknown,
))
}
},
None => None,
};
let mut status = ctx.setup_status.write().await;
if status.is_some() {
return Err(Error::new(
eyre!("Setup already in progress"),
ErrorKind::InvalidRequest,
));
}
*status = Some(Ok(SetupStatus {
bytes_transferred: 0,
total_bytes: None,
complete: false,
}));
drop(status);
tokio::task::spawn({
async move {
let ctx = ctx.clone();
match execute_inner(
ctx.clone(),
start_os_logicalname,
start_os_password,
recovery_source,
recovery_password,
)
.await
{
Ok((guid, hostname, tor_addr, root_ca)) => {
tracing::info!("Setup Complete!");
*ctx.setup_result.write().await = Some((
guid,
SetupResult {
tor_address: format!("https://{}", tor_addr),
lan_address: hostname.lan_address(),
root_ca: String::from_utf8(
root_ca.to_pem().expect("failed to serialize root ca"),
)
.expect("invalid pem string"),
},
));
*ctx.setup_status.write().await = Some(Ok(SetupStatus {
bytes_transferred: 0,
total_bytes: None,
complete: true,
}));
}
Err(e) => {
tracing::error!("Error Setting Up Server: {}", e);
tracing::debug!("{:?}", e);
*ctx.setup_status.write().await = Some(Err(e.into()));
}
}
}
});
Ok(())
let setup_ctx = ctx.clone();
ctx.run_setup(|| {
execute_inner(
setup_ctx,
start_os_logicalname,
start_os_password,
recovery_source,
recovery_password,
)
})?;
Ok(ctx.progress().await)
}
#[instrument(skip_all)]
// #[command(rpc_only)]
pub async fn complete(ctx: SetupContext) -> Result<SetupResult, Error> {
let (guid, setup_result) = if let Some((guid, setup_result)) = &*ctx.setup_result.read().await {
(guid.clone(), setup_result.clone())
} else {
return Err(Error::new(
match ctx.result.get() {
Some(Ok((res, ctx))) => {
let mut guid_file = File::create("/media/startos/config/disk.guid").await?;
guid_file.write_all(ctx.disk_guid.as_bytes()).await?;
guid_file.sync_all().await?;
Ok(res.clone())
}
Some(Err(e)) => Err(e.clone_output()),
None => Err(Error::new(
eyre!("setup.execute has not completed successfully"),
crate::ErrorKind::InvalidRequest,
));
};
let mut guid_file = File::create("/media/startos/config/disk.guid").await?;
guid_file.write_all(guid.as_bytes()).await?;
guid_file.sync_all().await?;
Ok(setup_result)
)),
}
}
#[instrument(skip_all)]
@@ -380,7 +351,22 @@ pub async fn execute_inner(
start_os_password: String,
recovery_source: Option<RecoverySource>,
recovery_password: Option<String>,
) -> Result<(Arc<String>, Hostname, OnionAddressV3, X509), Error> {
) -> Result<(SetupResult, RpcContext), Error> {
let progress = &ctx.progress;
let mut disk_phase = progress.add_phase("Formatting data drive".into(), Some(10));
let restore_phase = match &recovery_source {
Some(RecoverySource::Backup { .. }) => {
Some(progress.add_phase("Restoring backup".into(), Some(100)))
}
Some(RecoverySource::Migrate { .. }) => {
Some(progress.add_phase("Transferring data".into(), Some(100)))
}
None => None,
};
let init_phases = InitPhases::new(&progress);
let rpc_ctx_phases = InitRpcContextPhases::new(&progress);
disk_phase.start();
let encryption_password = if ctx.disable_encryption {
None
} else {
@@ -402,41 +388,70 @@ pub async fn execute_inner(
encryption_password,
)
.await?;
disk_phase.complete();
if let Some(RecoverySource::Backup { target }) = recovery_source {
recover(ctx, guid, start_os_password, target, recovery_password).await
} else if let Some(RecoverySource::Migrate { guid: old_guid }) = recovery_source {
migrate(ctx, guid, &old_guid, start_os_password).await
} else {
let (hostname, tor_addr, root_ca) = fresh_setup(&ctx, &start_os_password).await?;
Ok((guid, hostname, tor_addr, root_ca))
let progress = SetupExecuteProgress {
init_phases,
restore_phase,
rpc_ctx_phases,
};
match recovery_source {
Some(RecoverySource::Backup { target }) => {
recover(
&ctx,
guid,
start_os_password,
target,
recovery_password,
progress,
)
.await
}
Some(RecoverySource::Migrate { guid: old_guid }) => {
migrate(&ctx, guid, &old_guid, start_os_password, progress).await
}
None => fresh_setup(&ctx, guid, &start_os_password, progress).await,
}
}
pub struct SetupExecuteProgress {
pub init_phases: InitPhases,
pub restore_phase: Option<PhaseProgressTrackerHandle>,
pub rpc_ctx_phases: InitRpcContextPhases,
}
async fn fresh_setup(
ctx: &SetupContext,
guid: Arc<String>,
start_os_password: &str,
) -> Result<(Hostname, OnionAddressV3, X509), Error> {
SetupExecuteProgress {
init_phases,
rpc_ctx_phases,
..
}: SetupExecuteProgress,
) -> Result<(SetupResult, RpcContext), Error> {
let account = AccountInfo::new(start_os_password, root_ca_start_time().await?)?;
let db = ctx.db().await?;
db.put(&ROOT, &Database::init(&account)?).await?;
drop(db);
init(&ctx.config).await?;
Ok((
account.hostname,
account.tor_key.public().get_onion_address(),
account.root_ca_cert,
))
let InitResult { net_ctrl } = init(&ctx.config, init_phases).await?;
let rpc_ctx = RpcContext::init(&ctx.config, guid, Some(net_ctrl), rpc_ctx_phases).await?;
Ok(((&account).try_into()?, rpc_ctx))
}
#[instrument(skip_all)]
async fn recover(
ctx: SetupContext,
ctx: &SetupContext,
guid: Arc<String>,
start_os_password: String,
recovery_source: BackupTargetFS,
recovery_password: Option<String>,
) -> Result<(Arc<String>, Hostname, OnionAddressV3, X509), Error> {
progress: SetupExecuteProgress,
) -> Result<(SetupResult, RpcContext), Error> {
let recovery_source = TmpMountGuard::mount(&recovery_source, ReadWrite).await?;
recover_full_embassy(
ctx,
@@ -444,23 +459,26 @@ async fn recover(
start_os_password,
recovery_source,
recovery_password,
progress,
)
.await
}
#[instrument(skip_all)]
async fn migrate(
ctx: SetupContext,
ctx: &SetupContext,
guid: Arc<String>,
old_guid: &str,
start_os_password: String,
) -> Result<(Arc<String>, Hostname, OnionAddressV3, X509), Error> {
*ctx.setup_status.write().await = Some(Ok(SetupStatus {
bytes_transferred: 0,
total_bytes: None,
complete: false,
}));
SetupExecuteProgress {
init_phases,
restore_phase,
rpc_ctx_phases,
}: SetupExecuteProgress,
) -> Result<(SetupResult, RpcContext), Error> {
let mut restore_phase = restore_phase.or_not_found("restore progress")?;
restore_phase.start();
let _ = crate::disk::main::import(
&old_guid,
"/media/startos/migrate",
@@ -500,20 +518,12 @@ async fn migrate(
res = async {
loop {
tokio::time::sleep(Duration::from_secs(1)).await;
*ctx.setup_status.write().await = Some(Ok(SetupStatus {
bytes_transferred: 0,
total_bytes: Some(main_transfer_size.load() + package_data_transfer_size.load()),
complete: false,
}));
restore_phase.set_total(main_transfer_size.load() + package_data_transfer_size.load());
}
} => res,
};
*ctx.setup_status.write().await = Some(Ok(SetupStatus {
bytes_transferred: 0,
total_bytes: Some(size),
complete: false,
}));
restore_phase.set_total(size);
let main_transfer_progress = Counter::new(0, ordering);
let package_data_transfer_progress = Counter::new(0, ordering);
@@ -529,18 +539,17 @@ async fn migrate(
res = async {
loop {
tokio::time::sleep(Duration::from_secs(1)).await;
*ctx.setup_status.write().await = Some(Ok(SetupStatus {
bytes_transferred: main_transfer_progress.load() + package_data_transfer_progress.load(),
total_bytes: Some(size),
complete: false,
}));
restore_phase.set_done(main_transfer_progress.load() + package_data_transfer_progress.load());
}
} => res,
}
let (hostname, tor_addr, root_ca) = setup_init(&ctx, Some(start_os_password)).await?;
crate::disk::main::export(&old_guid, "/media/startos/migrate").await?;
restore_phase.complete();
Ok((guid, hostname, tor_addr, root_ca))
let (account, net_ctrl) = setup_init(&ctx, Some(start_os_password), init_phases).await?;
let rpc_ctx = RpcContext::init(&ctx.config, guid, Some(net_ctrl), rpc_ctx_phases).await?;
Ok(((&account).try_into()?, rpc_ctx))
}

View File

@@ -20,9 +20,7 @@ use ts_rs::TS;
use crate::context::{CliContext, RpcContext};
use crate::notifications::{notify, NotificationLevel};
use crate::prelude::*;
use crate::progress::{
FullProgressTracker, FullProgressTrackerHandle, PhaseProgressTrackerHandle, PhasedProgressBar,
};
use crate::progress::{FullProgressTracker, PhaseProgressTrackerHandle, PhasedProgressBar};
use crate::registry::asset::RegistryAsset;
use crate::registry::context::{RegistryContext, RegistryUrlParams};
use crate::registry::os::index::OsVersionInfo;
@@ -34,6 +32,7 @@ use crate::s9pk::merkle_archive::source::multi_cursor_file::MultiCursorFile;
use crate::sound::{
CIRCLE_OF_5THS_SHORT, UPDATE_FAILED_1, UPDATE_FAILED_2, UPDATE_FAILED_3, UPDATE_FAILED_4,
};
use crate::util::net::WebSocketExt;
use crate::util::Invoke;
use crate::PLATFORM;
@@ -91,50 +90,47 @@ pub async fn update_system(
.add(
guid.clone(),
RpcContinuation::ws(
Box::new(|mut ws| {
async move {
if let Err(e) = async {
let mut sub = ctx
|mut ws| async move {
if let Err(e) = async {
let mut sub = ctx
.db
.subscribe(
"/public/serverInfo/statusInfo/updateProgress"
.parse::<JsonPointer>()
.with_kind(ErrorKind::Database)?,
)
.await;
while {
let progress = ctx
.db
.subscribe(
"/public/serverInfo/statusInfo/updateProgress"
.parse::<JsonPointer>()
.with_kind(ErrorKind::Database)?,
)
.await;
while {
let progress = ctx
.db
.peek()
.await
.into_public()
.into_server_info()
.into_status_info()
.into_update_progress()
.de()?;
ws.send(axum::extract::ws::Message::Text(
serde_json::to_string(&progress)
.with_kind(ErrorKind::Serialization)?,
))
.peek()
.await
.with_kind(ErrorKind::Network)?;
progress.is_some()
} {
sub.recv().await;
}
ws.close().await.with_kind(ErrorKind::Network)?;
Ok::<_, Error>(())
}
.await
{
tracing::error!("Error returning progress of update: {e}");
tracing::debug!("{e:?}")
.into_public()
.into_server_info()
.into_status_info()
.into_update_progress()
.de()?;
ws.send(axum::extract::ws::Message::Text(
serde_json::to_string(&progress)
.with_kind(ErrorKind::Serialization)?,
))
.await
.with_kind(ErrorKind::Network)?;
progress.is_some()
} {
sub.recv().await;
}
ws.normal_close("complete").await?;
Ok::<_, Error>(())
}
.boxed()
}),
.await
{
tracing::error!("Error returning progress of update: {e}");
tracing::debug!("{e:?}")
}
},
Duration::from_secs(30),
),
)
@@ -250,13 +246,12 @@ async fn maybe_do_update(
asset.validate(SIG_CONTEXT, asset.all_signers())?;
let mut progress = FullProgressTracker::new();
let progress_handle = progress.handle();
let mut download_phase = progress_handle.add_phase("Downloading File".into(), Some(100));
let progress = FullProgressTracker::new();
let mut download_phase = progress.add_phase("Downloading File".into(), Some(100));
download_phase.set_total(asset.commitment.size);
let reverify_phase = progress_handle.add_phase("Reverifying File".into(), Some(10));
let sync_boot_phase = progress_handle.add_phase("Syncing Boot Files".into(), Some(1));
let finalize_phase = progress_handle.add_phase("Finalizing Update".into(), Some(1));
let reverify_phase = progress.add_phase("Reverifying File".into(), Some(10));
let sync_boot_phase = progress.add_phase("Syncing Boot Files".into(), Some(1));
let finalize_phase = progress.add_phase("Finalizing Update".into(), Some(1));
let start_progress = progress.snapshot();
@@ -287,7 +282,7 @@ async fn maybe_do_update(
));
}
let progress_task = NonDetachingJoinHandle::from(tokio::spawn(progress.sync_to_db(
let progress_task = NonDetachingJoinHandle::from(tokio::spawn(progress.clone().sync_to_db(
ctx.db.clone(),
|db| {
db.as_public_mut()
@@ -304,7 +299,7 @@ async fn maybe_do_update(
ctx.clone(),
asset,
UpdateProgressHandles {
progress_handle,
progress,
download_phase,
reverify_phase,
sync_boot_phase,
@@ -373,7 +368,7 @@ async fn maybe_do_update(
}
struct UpdateProgressHandles {
progress_handle: FullProgressTrackerHandle,
progress: FullProgressTracker,
download_phase: PhaseProgressTrackerHandle,
reverify_phase: PhaseProgressTrackerHandle,
sync_boot_phase: PhaseProgressTrackerHandle,
@@ -385,7 +380,7 @@ async fn do_update(
ctx: RpcContext,
asset: RegistryAsset<Blake3Commitment>,
UpdateProgressHandles {
progress_handle,
progress,
mut download_phase,
mut reverify_phase,
mut sync_boot_phase,
@@ -436,7 +431,7 @@ async fn do_update(
.await?;
finalize_phase.complete();
progress_handle.complete();
progress.complete();
Ok(())
}

View File

@@ -5,9 +5,10 @@ use std::time::Duration;
use axum::body::Body;
use axum::response::Response;
use futures::{FutureExt, StreamExt};
use futures::StreamExt;
use http::header::CONTENT_LENGTH;
use http::StatusCode;
use imbl_value::InternedString;
use tokio::fs::File;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::sync::watch;
@@ -19,68 +20,70 @@ use crate::s9pk::merkle_archive::source::multi_cursor_file::MultiCursorFile;
use crate::s9pk::merkle_archive::source::ArchiveSource;
use crate::util::io::TmpDir;
pub async fn upload(ctx: &RpcContext) -> Result<(Guid, UploadingFile), Error> {
pub async fn upload(
ctx: &RpcContext,
session: InternedString,
) -> Result<(Guid, UploadingFile), Error> {
let guid = Guid::new();
let (mut handle, file) = UploadingFile::new().await?;
ctx.rpc_continuations
.add(
guid.clone(),
RpcContinuation::rest(
Box::new(|request| {
async move {
let headers = request.headers();
let content_length = match headers.get(CONTENT_LENGTH).map(|a| a.to_str()) {
None => {
return Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Content-Length is required"))
.with_kind(ErrorKind::Network)
}
Some(Err(_)) => {
RpcContinuation::rest_authed(
ctx,
session,
|request| async move {
let headers = request.headers();
let content_length = match headers.get(CONTENT_LENGTH).map(|a| a.to_str()) {
None => {
return Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Content-Length is required"))
.with_kind(ErrorKind::Network)
}
Some(Err(_)) => {
return Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid Content-Length"))
.with_kind(ErrorKind::Network)
}
Some(Ok(a)) => match a.parse::<u64>() {
Err(_) => {
return Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid Content-Length"))
.with_kind(ErrorKind::Network)
}
Some(Ok(a)) => match a.parse::<u64>() {
Err(_) => {
return Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid Content-Length"))
.with_kind(ErrorKind::Network)
}
Ok(a) => a,
},
};
Ok(a) => a,
},
};
handle
.progress
.send_modify(|p| p.expected_size = Some(content_length));
handle
.progress
.send_modify(|p| p.expected_size = Some(content_length));
let mut body = request.into_body().into_data_stream();
while let Some(next) = body.next().await {
if let Err(e) = async {
handle
.write_all(&next.map_err(|e| {
std::io::Error::new(std::io::ErrorKind::Other, e)
})?)
.await?;
Ok(())
}
.await
{
handle.progress.send_if_modified(|p| p.handle_error(&e));
break;
}
let mut body = request.into_body().into_data_stream();
while let Some(next) = body.next().await {
if let Err(e) = async {
handle
.write_all(&next.map_err(|e| {
std::io::Error::new(std::io::ErrorKind::Other, e)
})?)
.await?;
Ok(())
}
.await
{
handle.progress.send_if_modified(|p| p.handle_error(&e));
break;
}
Response::builder()
.status(StatusCode::NO_CONTENT)
.body(Body::empty())
.with_kind(ErrorKind::Network)
}
.boxed()
}),
Response::builder()
.status(StatusCode::NO_CONTENT)
.body(Body::empty())
.with_kind(ErrorKind::Network)
},
Duration::from_secs(30),
),
)

View File

@@ -8,6 +8,7 @@ use helpers::NonDetachingJoinHandle;
use tokio::sync::{mpsc, oneshot};
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::util::actor::background::{BackgroundJobQueue, BackgroundJobRunner};
use crate::util::actor::{Actor, ConflictFn, Handler, PendingMessageStrategy, Request};
@@ -18,6 +19,7 @@ struct ConcurrentRunner<A> {
waiting: Vec<Request<A>>,
recv: mpsc::UnboundedReceiver<Request<A>>,
handlers: Vec<(
Guid,
Arc<ConflictFn<A>>,
oneshot::Sender<Box<dyn Any + Send>>,
BoxFuture<'static, Box<dyn Any + Send>>,
@@ -41,16 +43,21 @@ impl<A: Actor + Clone> Future for ConcurrentRunner<A> {
}
});
if this.shutdown.is_some() {
while let std::task::Poll::Ready(Some((msg, reply))) = this.recv.poll_recv(cx) {
if this.handlers.iter().any(|(f, _, _)| f(&*msg)) {
this.waiting.push((msg, reply));
while let std::task::Poll::Ready(Some((id, msg, reply))) = this.recv.poll_recv(cx) {
if this
.handlers
.iter()
.any(|(hid, f, _, _)| &id != hid && f(&*msg))
{
this.waiting.push((id, msg, reply));
} else {
let mut actor = this.actor.clone();
let queue = this.queue.clone();
this.handlers.push((
id.clone(),
msg.conflicts_with(),
reply,
async move { msg.handle_with(&mut actor, &queue).await }.boxed(),
async move { msg.handle_with(id, &mut actor, &queue).await }.boxed(),
))
}
}
@@ -62,29 +69,34 @@ impl<A: Actor + Clone> Future for ConcurrentRunner<A> {
.handlers
.iter_mut()
.enumerate()
.filter_map(|(i, (_, _, f))| match f.poll_unpin(cx) {
.filter_map(|(i, (_, _, _, f))| match f.poll_unpin(cx) {
std::task::Poll::Pending => None,
std::task::Poll::Ready(res) => Some((i, res)),
})
.collect::<Vec<_>>();
for (idx, res) in complete.into_iter().rev() {
#[allow(clippy::let_underscore_future)]
let (f, reply, _) = this.handlers.swap_remove(idx);
let (_, f, reply, _) = this.handlers.swap_remove(idx);
let _ = reply.send(res);
// TODO: replace with Vec::extract_if once stable
if this.shutdown.is_some() {
let mut i = 0;
while i < this.waiting.len() {
if f(&*this.waiting[i].0)
&& !this.handlers.iter().any(|(f, _, _)| f(&*this.waiting[i].0))
if f(&*this.waiting[i].1)
&& !this
.handlers
.iter()
.any(|(_, f, _, _)| f(&*this.waiting[i].1))
{
let (msg, reply) = this.waiting.remove(i);
let (id, msg, reply) = this.waiting.remove(i);
let mut actor = this.actor.clone();
let queue = this.queue.clone();
this.handlers.push((
id.clone(),
msg.conflicts_with(),
reply,
async move { msg.handle_with(&mut actor, &queue).await }.boxed(),
async move { msg.handle_with(id, &mut actor, &queue).await }
.boxed(),
));
cont = true;
} else {
@@ -137,6 +149,7 @@ impl<A: Actor + Clone> ConcurrentActor<A> {
/// Message is guaranteed to be queued immediately
pub fn queue<M: Send + 'static>(
&self,
id: Guid,
message: M,
) -> impl Future<Output = Result<A::Response, Error>>
where
@@ -150,7 +163,7 @@ impl<A: Actor + Clone> ConcurrentActor<A> {
}
let (reply_send, reply_recv) = oneshot::channel();
self.messenger
.send((Box::new(message), reply_send))
.send((id, Box::new(message), reply_send))
.unwrap();
futures::future::Either::Right(
reply_recv
@@ -170,11 +183,11 @@ impl<A: Actor + Clone> ConcurrentActor<A> {
)
}
pub async fn send<M: Send + 'static>(&self, message: M) -> Result<A::Response, Error>
pub async fn send<M: Send + 'static>(&self, id: Guid, message: M) -> Result<A::Response, Error>
where
A: Handler<M>,
{
self.queue(message).await
self.queue(id, message).await
}
pub async fn shutdown(self, strategy: PendingMessageStrategy) {

View File

@@ -9,6 +9,7 @@ use tokio::sync::oneshot;
#[allow(unused_imports)]
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::util::actor::background::BackgroundJobQueue;
pub mod background;
@@ -28,6 +29,7 @@ pub trait Handler<M: Any + Send>: Actor {
}
fn handle(
&mut self,
id: Guid,
msg: M,
jobs: &BackgroundJobQueue,
) -> impl Future<Output = Self::Response> + Send;
@@ -39,6 +41,7 @@ trait Message<A>: Send + Any {
fn conflicts_with(&self) -> Arc<ConflictFn<A>>;
fn handle_with<'a>(
self: Box<Self>,
id: Guid,
actor: &'a mut A,
jobs: &'a BackgroundJobQueue,
) -> BoxFuture<'a, Box<dyn Any + Send>>;
@@ -52,10 +55,11 @@ where
}
fn handle_with<'a>(
self: Box<Self>,
id: Guid,
actor: &'a mut A,
jobs: &'a BackgroundJobQueue,
) -> BoxFuture<'a, Box<dyn Any + Send>> {
async move { Box::new(actor.handle(*self, jobs).await) as Box<dyn Any + Send> }.boxed()
async move { Box::new(actor.handle(id, *self, jobs).await) as Box<dyn Any + Send> }.boxed()
}
}
impl<A: Actor> dyn Message<A> {
@@ -80,7 +84,11 @@ impl<A: Actor> dyn Message<A> {
}
}
type Request<A> = (Box<dyn Message<A>>, oneshot::Sender<Box<dyn Any + Send>>);
type Request<A> = (
Guid,
Box<dyn Message<A>>,
oneshot::Sender<Box<dyn Any + Send>>,
);
pub enum PendingMessageStrategy {
CancelAll,

View File

@@ -7,6 +7,7 @@ use tokio::sync::oneshot::error::TryRecvError;
use tokio::sync::{mpsc, oneshot};
use crate::prelude::*;
use crate::rpc_continuations::Guid;
use crate::util::actor::background::BackgroundJobQueue;
use crate::util::actor::{Actor, Handler, PendingMessageStrategy, Request};
@@ -26,9 +27,9 @@ impl<A: Actor> SimpleActor<A> {
tokio::select! {
_ = &mut runner => (),
msg = messenger_recv.recv() => match msg {
Some((msg, reply)) if shutdown_recv.try_recv() == Err(TryRecvError::Empty) => {
Some((id, msg, reply)) if shutdown_recv.try_recv() == Err(TryRecvError::Empty) => {
tokio::select! {
res = msg.handle_with(&mut actor, &queue) => { let _ = reply.send(res); },
res = msg.handle_with(id, &mut actor, &queue) => { let _ = reply.send(res); },
_ = &mut runner => (),
}
}
@@ -60,7 +61,7 @@ impl<A: Actor> SimpleActor<A> {
}
let (reply_send, reply_recv) = oneshot::channel();
self.messenger
.send((Box::new(message), reply_send))
.send((Guid::new(), Box::new(message), reply_send))
.unwrap();
futures::future::Either::Right(
reply_recv

View File

@@ -1,4 +1,4 @@
use std::collections::{BTreeSet, VecDeque};
use std::collections::VecDeque;
use std::future::Future;
use std::io::Cursor;
use std::os::unix::prelude::MetadataExt;
@@ -274,6 +274,81 @@ pub fn response_to_reader(response: reqwest::Response) -> impl AsyncRead + Unpin
}))
}
#[pin_project::pin_project]
pub struct IOHook<'a, T> {
#[pin]
pub io: T,
pre_write: Option<Box<dyn FnMut(&[u8]) -> Result<(), std::io::Error> + Send + 'a>>,
post_write: Option<Box<dyn FnMut(&[u8]) + Send + 'a>>,
post_read: Option<Box<dyn FnMut(&[u8]) + Send + 'a>>,
}
impl<'a, T> IOHook<'a, T> {
pub fn new(io: T) -> Self {
Self {
io,
pre_write: None,
post_write: None,
post_read: None,
}
}
pub fn into_inner(self) -> T {
self.io
}
pub fn pre_write<F: FnMut(&[u8]) -> Result<(), std::io::Error> + Send + 'a>(&mut self, f: F) {
self.pre_write = Some(Box::new(f))
}
pub fn post_write<F: FnMut(&[u8]) + Send + 'a>(&mut self, f: F) {
self.post_write = Some(Box::new(f))
}
pub fn post_read<F: FnMut(&[u8]) + Send + 'a>(&mut self, f: F) {
self.post_read = Some(Box::new(f))
}
}
impl<'a, T: AsyncWrite> AsyncWrite for IOHook<'a, T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let this = self.project();
if let Some(pre_write) = this.pre_write {
pre_write(buf)?;
}
let written = futures::ready!(this.io.poll_write(cx, buf)?);
if let Some(post_write) = this.post_write {
post_write(&buf[..written]);
}
Poll::Ready(Ok(written))
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
self.project().io.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
self.project().io.poll_shutdown(cx)
}
}
impl<'a, T: AsyncRead> AsyncRead for IOHook<'a, T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let this = self.project();
let start = buf.filled().len();
futures::ready!(this.io.poll_read(cx, buf)?);
if let Some(post_read) = this.post_read {
post_read(&buf.filled()[start..]);
}
Poll::Ready(Ok(()))
}
}
#[pin_project::pin_project]
pub struct BufferedWriteReader {
#[pin]
@@ -631,16 +706,16 @@ impl<S: AsyncRead + AsyncWrite> AsyncRead for TimeoutStream<S> {
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let mut this = self.project();
if let std::task::Poll::Ready(_) = this.sleep.as_mut().poll(cx) {
let timeout = this.sleep.as_mut().poll(cx);
let res = this.stream.poll_read(cx, buf);
if res.is_ready() {
this.sleep.reset(Instant::now() + *this.timeout);
} else if timeout.is_ready() {
return std::task::Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"timed out",
)));
}
let res = this.stream.poll_read(cx, buf);
if res.is_ready() {
this.sleep.reset(Instant::now() + *this.timeout);
}
res
}
}
@@ -650,10 +725,16 @@ impl<S: AsyncRead + AsyncWrite> AsyncWrite for TimeoutStream<S> {
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
let this = self.project();
let mut this = self.project();
let timeout = this.sleep.as_mut().poll(cx);
let res = this.stream.poll_write(cx, buf);
if res.is_ready() {
this.sleep.reset(Instant::now() + *this.timeout);
} else if timeout.is_ready() {
return std::task::Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"timed out",
)));
}
res
}
@@ -661,10 +742,16 @@ impl<S: AsyncRead + AsyncWrite> AsyncWrite for TimeoutStream<S> {
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project();
let mut this = self.project();
let timeout = this.sleep.as_mut().poll(cx);
let res = this.stream.poll_flush(cx);
if res.is_ready() {
this.sleep.reset(Instant::now() + *this.timeout);
} else if timeout.is_ready() {
return std::task::Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"timed out",
)));
}
res
}
@@ -672,17 +759,21 @@ impl<S: AsyncRead + AsyncWrite> AsyncWrite for TimeoutStream<S> {
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project();
let mut this = self.project();
let timeout = this.sleep.as_mut().poll(cx);
let res = this.stream.poll_shutdown(cx);
if res.is_ready() {
this.sleep.reset(Instant::now() + *this.timeout);
} else if timeout.is_ready() {
return std::task::Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"timed out",
)));
}
res
}
}
pub struct TmpFile {}
#[derive(Debug)]
pub struct TmpDir {
path: PathBuf,
@@ -707,6 +798,14 @@ impl TmpDir {
tokio::fs::remove_dir_all(&self.path).await?;
Ok(())
}
pub async fn gc(self: Arc<Self>) -> Result<(), Error> {
if let Ok(dir) = Arc::try_unwrap(self) {
dir.delete().await
} else {
Ok(())
}
}
}
impl std::ops::Deref for TmpDir {
type Target = Path;
@@ -762,7 +861,7 @@ fn poll_flush_prefix<W: AsyncWrite>(
flush_writer: bool,
) -> Poll<Result<(), std::io::Error>> {
while let Some(mut cur) = prefix.pop_front() {
let buf = cur.remaining_slice();
let buf = CursorExt::remaining_slice(&cur);
if !buf.is_empty() {
match writer.as_mut().poll_write(cx, buf)? {
Poll::Ready(n) if n == buf.len() => (),

View File

@@ -1,4 +1,4 @@
use std::collections::BTreeMap;
use std::collections::{BTreeMap, VecDeque};
use std::future::Future;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
@@ -11,6 +11,8 @@ use std::time::Duration;
use async_trait::async_trait;
use color_eyre::eyre::{self, eyre};
use fd_lock_rs::FdLock;
use futures::future::BoxFuture;
use futures::FutureExt;
use helpers::canonicalize;
pub use helpers::NonDetachingJoinHandle;
use imbl_value::InternedString;
@@ -19,7 +21,8 @@ pub use models::VersionString;
use pin_project::pin_project;
use sha2::Digest;
use tokio::fs::File;
use tokio::sync::{Mutex, OwnedMutexGuard, RwLock};
use tokio::io::{AsyncRead, AsyncReadExt, BufReader};
use tokio::sync::{oneshot, Mutex, OwnedMutexGuard, RwLock};
use tracing::instrument;
use crate::shutdown::Shutdown;
@@ -33,6 +36,7 @@ pub mod http_reader;
pub mod io;
pub mod logger;
pub mod lshw;
pub mod net;
pub mod rpc;
pub mod rpc_client;
pub mod serde;
@@ -62,11 +66,16 @@ pub trait Invoke<'a> {
where
Self: 'ext,
'ext: 'a;
fn pipe<'ext: 'a>(
&'ext mut self,
next: &'ext mut tokio::process::Command,
) -> Self::Extended<'ext>;
fn timeout<'ext: 'a>(&'ext mut self, timeout: Option<Duration>) -> Self::Extended<'ext>;
fn input<'ext: 'a, Input: tokio::io::AsyncRead + Unpin + Send>(
&'ext mut self,
input: Option<&'ext mut Input>,
) -> Self::Extended<'ext>;
fn capture<'ext: 'a>(&'ext mut self, capture: bool) -> Self::Extended<'ext>;
fn invoke(
&mut self,
error_kind: crate::ErrorKind,
@@ -76,7 +85,20 @@ pub trait Invoke<'a> {
pub struct ExtendedCommand<'a> {
cmd: &'a mut tokio::process::Command,
timeout: Option<Duration>,
input: Option<&'a mut (dyn tokio::io::AsyncRead + Unpin + Send)>,
input: Option<&'a mut (dyn AsyncRead + Unpin + Send)>,
pipe: VecDeque<&'a mut tokio::process::Command>,
capture: bool,
}
impl<'a> From<&'a mut tokio::process::Command> for ExtendedCommand<'a> {
fn from(value: &'a mut tokio::process::Command) -> Self {
ExtendedCommand {
cmd: value,
timeout: None,
input: None,
pipe: VecDeque::new(),
capture: true,
}
}
}
impl<'a> std::ops::Deref for ExtendedCommand<'a> {
type Target = tokio::process::Command;
@@ -95,35 +117,38 @@ impl<'a> Invoke<'a> for tokio::process::Command {
where
Self: 'ext,
'ext: 'a;
fn timeout<'ext: 'a>(&'ext mut self, timeout: Option<Duration>) -> Self::Extended<'ext> {
ExtendedCommand {
cmd: self,
timeout,
input: None,
}
fn pipe<'ext: 'a>(
&'ext mut self,
next: &'ext mut tokio::process::Command,
) -> Self::Extended<'ext> {
let mut cmd = ExtendedCommand::from(self);
cmd.pipe.push_back(next);
cmd
}
fn input<'ext: 'a, Input: tokio::io::AsyncRead + Unpin + Send>(
fn timeout<'ext: 'a>(&'ext mut self, timeout: Option<Duration>) -> Self::Extended<'ext> {
let mut cmd = ExtendedCommand::from(self);
cmd.timeout = timeout;
cmd
}
fn input<'ext: 'a, Input: AsyncRead + Unpin + Send>(
&'ext mut self,
input: Option<&'ext mut Input>,
) -> Self::Extended<'ext> {
ExtendedCommand {
cmd: self,
timeout: None,
input: if let Some(input) = input {
Some(&mut *input)
} else {
None
},
}
let mut cmd = ExtendedCommand::from(self);
cmd.input = if let Some(input) = input {
Some(&mut *input)
} else {
None
};
cmd
}
fn capture<'ext: 'a>(&'ext mut self, capture: bool) -> Self::Extended<'ext> {
let mut cmd = ExtendedCommand::from(self);
cmd.capture = capture;
cmd
}
async fn invoke(&mut self, error_kind: crate::ErrorKind) -> Result<Vec<u8>, Error> {
ExtendedCommand {
cmd: self,
timeout: None,
input: None,
}
.invoke(error_kind)
.await
ExtendedCommand::from(self).invoke(error_kind).await
}
}
@@ -132,6 +157,13 @@ impl<'a> Invoke<'a> for ExtendedCommand<'a> {
where
Self: 'ext,
'ext: 'a;
fn pipe<'ext: 'a>(
&'ext mut self,
next: &'ext mut tokio::process::Command,
) -> Self::Extended<'ext> {
self.pipe.push_back(next.kill_on_drop(true));
self
}
fn timeout<'ext: 'a>(&'ext mut self, timeout: Option<Duration>) -> Self::Extended<'ext> {
self.timeout = timeout;
self
@@ -147,39 +179,150 @@ impl<'a> Invoke<'a> for ExtendedCommand<'a> {
};
self
}
fn capture<'ext: 'a>(&'ext mut self, capture: bool) -> Self::Extended<'ext> {
self.capture = capture;
self
}
#[instrument(skip_all)]
async fn invoke(&mut self, error_kind: crate::ErrorKind) -> Result<Vec<u8>, Error> {
let cmd_str = self
.cmd
.as_std()
.get_program()
.to_string_lossy()
.into_owned();
self.cmd.kill_on_drop(true);
if self.input.is_some() {
self.cmd.stdin(Stdio::piped());
}
self.cmd.stdout(Stdio::piped());
self.cmd.stderr(Stdio::piped());
let mut child = self.cmd.spawn().with_kind(error_kind)?;
if let (Some(mut stdin), Some(input)) = (child.stdin.take(), self.input.take()) {
use tokio::io::AsyncWriteExt;
tokio::io::copy(input, &mut stdin).await?;
stdin.flush().await?;
stdin.shutdown().await?;
drop(stdin);
if self.pipe.is_empty() {
if self.capture {
self.cmd.stdout(Stdio::piped());
self.cmd.stderr(Stdio::piped());
}
let mut child = self.cmd.spawn().with_ctx(|_| (error_kind, &cmd_str))?;
if let (Some(mut stdin), Some(input)) = (child.stdin.take(), self.input.take()) {
use tokio::io::AsyncWriteExt;
tokio::io::copy(input, &mut stdin).await?;
stdin.flush().await?;
stdin.shutdown().await?;
drop(stdin);
}
let res = match self.timeout {
None => child
.wait_with_output()
.await
.with_ctx(|_| (error_kind, &cmd_str))?,
Some(t) => tokio::time::timeout(t, child.wait_with_output())
.await
.with_kind(ErrorKind::Timeout)?
.with_ctx(|_| (error_kind, &cmd_str))?,
};
crate::ensure_code!(
res.status.success(),
error_kind,
"{}",
Some(&res.stderr)
.filter(|a| !a.is_empty())
.or(Some(&res.stdout))
.filter(|a| !a.is_empty())
.and_then(|a| std::str::from_utf8(a).ok())
.unwrap_or(&format!(
"{} exited with code {}",
self.cmd.as_std().get_program().to_string_lossy(),
res.status
))
);
Ok(res.stdout)
} else {
let mut futures = Vec::<BoxFuture<'_, Result<(), Error>>>::new(); // todo: predict capacity
let mut cmds = std::mem::take(&mut self.pipe);
cmds.push_front(&mut *self.cmd);
let len = cmds.len();
let timeout = self.timeout;
let mut prev = self
.input
.take()
.map(|i| Box::new(i) as Box<dyn AsyncRead + Unpin + Send>);
for (idx, cmd) in IntoIterator::into_iter(cmds).enumerate() {
let last = idx == len - 1;
if self.capture || !last {
cmd.stdout(Stdio::piped());
}
if self.capture {
cmd.stderr(Stdio::piped());
}
if prev.is_some() {
cmd.stdin(Stdio::piped());
}
let mut child = cmd.spawn().with_kind(error_kind)?;
let input = std::mem::replace(
&mut prev,
child
.stdout
.take()
.map(|i| Box::new(BufReader::new(i)) as Box<dyn AsyncRead + Unpin + Send>),
);
futures.push(
async move {
if let (Some(mut stdin), Some(mut input)) = (child.stdin.take(), input) {
use tokio::io::AsyncWriteExt;
tokio::io::copy(&mut input, &mut stdin).await?;
stdin.flush().await?;
stdin.shutdown().await?;
drop(stdin);
}
let res = match timeout {
None => child.wait_with_output().await?,
Some(t) => tokio::time::timeout(t, child.wait_with_output())
.await
.with_kind(ErrorKind::Timeout)??,
};
crate::ensure_code!(
res.status.success(),
error_kind,
"{}",
Some(&res.stderr)
.filter(|a| !a.is_empty())
.or(Some(&res.stdout))
.filter(|a| !a.is_empty())
.and_then(|a| std::str::from_utf8(a).ok())
.unwrap_or(&format!(
"{} exited with code {}",
cmd.as_std().get_program().to_string_lossy(),
res.status
))
);
Ok(())
}
.boxed(),
);
}
let (send, recv) = oneshot::channel();
futures.push(
async move {
if let Some(mut prev) = prev {
let mut res = Vec::new();
prev.read_to_end(&mut res).await?;
send.send(res).unwrap();
} else {
send.send(Vec::new()).unwrap();
}
Ok(())
}
.boxed(),
);
futures::future::try_join_all(futures).await?;
Ok(recv.await.unwrap())
}
let res = match self.timeout {
None => child.wait_with_output().await?,
Some(t) => tokio::time::timeout(t, child.wait_with_output())
.await
.with_kind(ErrorKind::Timeout)??,
};
crate::ensure_code!(
res.status.success(),
error_kind,
"{}",
Some(&res.stderr)
.filter(|a| !a.is_empty())
.or(Some(&res.stdout))
.filter(|a| !a.is_empty())
.and_then(|a| std::str::from_utf8(a).ok())
.unwrap_or(&format!("Unknown Error ({})", res.status))
);
Ok(res.stdout)
}
}

View File

@@ -0,0 +1,24 @@
use std::borrow::Cow;
use axum::extract::ws::{self, CloseFrame};
use futures::Future;
use crate::prelude::*;
pub trait WebSocketExt {
fn normal_close(
self,
msg: impl Into<Cow<'static, str>>,
) -> impl Future<Output = Result<(), Error>>;
}
impl WebSocketExt for ws::WebSocket {
async fn normal_close(mut self, msg: impl Into<Cow<'static, str>>) -> Result<(), Error> {
self.send(ws::Message::Close(Some(CloseFrame {
code: 1000,
reason: msg.into(),
})))
.await
.with_kind(ErrorKind::Network)
}
}

View File

@@ -22,8 +22,8 @@ use ts_rs::TS;
use super::IntoDoubleEndedIterator;
use crate::prelude::*;
use crate::util::Apply;
use crate::util::clap::FromStrParser;
use crate::util::Apply;
pub fn deserialize_from_str<
'de,
@@ -1040,15 +1040,19 @@ impl<T: AsRef<[u8]>> std::fmt::Display for Base64<T> {
f.write_str(&base64::encode(self.0.as_ref()))
}
}
impl<T: TryFrom<Vec<u8>>> FromStr for Base64<T>
{
impl<T: TryFrom<Vec<u8>>> FromStr for Base64<T> {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
base64::decode(&s)
.with_kind(ErrorKind::Deserialization)?
.apply(TryFrom::try_from)
.map(Self)
.map_err(|_| Error::new(eyre!("failed to create from buffer"), ErrorKind::Deserialization))
.map_err(|_| {
Error::new(
eyre!("failed to create from buffer"),
ErrorKind::Deserialization,
)
})
}
}
impl<'de, T: TryFrom<Vec<u8>>> Deserialize<'de> for Base64<T> {

View File

@@ -7,6 +7,7 @@ use imbl_value::InternedString;
use crate::db::model::Database;
use crate::prelude::*;
use crate::progress::PhaseProgressTrackerHandle;
use crate::Error;
mod v0_3_5;
@@ -85,11 +86,12 @@ where
&self,
version: &V,
db: &TypedPatchDb<Database>,
progress: &mut PhaseProgressTrackerHandle,
) -> impl Future<Output = Result<(), Error>> + Send {
async {
match self.semver().cmp(&version.semver()) {
Ordering::Greater => self.rollback_to_unchecked(version, db).await,
Ordering::Less => version.migrate_from_unchecked(self, db).await,
Ordering::Greater => self.rollback_to_unchecked(version, db, progress).await,
Ordering::Less => version.migrate_from_unchecked(self, db, progress).await,
Ordering::Equal => Ok(()),
}
}
@@ -98,11 +100,15 @@ where
&'a self,
version: &'a V,
db: &'a TypedPatchDb<Database>,
progress: &'a mut PhaseProgressTrackerHandle,
) -> BoxFuture<'a, Result<(), Error>> {
progress.add_total(1);
async {
let previous = Self::Previous::new();
if version.semver() < previous.semver() {
previous.migrate_from_unchecked(version, db).await?;
previous
.migrate_from_unchecked(version, db, progress)
.await?;
} else if version.semver() > previous.semver() {
return Err(Error::new(
eyre!(
@@ -115,6 +121,7 @@ where
tracing::info!("{} -> {}", previous.semver(), self.semver(),);
self.up(db).await?;
self.commit(db).await?;
*progress += 1;
Ok(())
}
.boxed()
@@ -123,14 +130,18 @@ where
&'a self,
version: &'a V,
db: &'a TypedPatchDb<Database>,
progress: &'a mut PhaseProgressTrackerHandle,
) -> BoxFuture<'a, Result<(), Error>> {
async {
let previous = Self::Previous::new();
tracing::info!("{} -> {}", self.semver(), previous.semver(),);
self.down(db).await?;
previous.commit(db).await?;
*progress += 1;
if version.semver() < previous.semver() {
previous.rollback_to_unchecked(version, db).await?;
previous
.rollback_to_unchecked(version, db, progress)
.await?;
} else if version.semver() > previous.semver() {
return Err(Error::new(
eyre!(
@@ -196,7 +207,11 @@ where
}
}
pub async fn init(db: &TypedPatchDb<Database>) -> Result<(), Error> {
pub async fn init(
db: &TypedPatchDb<Database>,
mut progress: PhaseProgressTrackerHandle,
) -> Result<(), Error> {
progress.start();
let version = Version::from_util_version(
db.peek()
.await
@@ -213,10 +228,10 @@ pub async fn init(db: &TypedPatchDb<Database>) -> Result<(), Error> {
ErrorKind::MigrationFailed,
));
}
Version::V0_3_5(v) => v.0.migrate_to(&Current::new(), &db).await?,
Version::V0_3_5_1(v) => v.0.migrate_to(&Current::new(), &db).await?,
Version::V0_3_5_2(v) => v.0.migrate_to(&Current::new(), &db).await?,
Version::V0_3_6(v) => v.0.migrate_to(&Current::new(), &db).await?,
Version::V0_3_5(v) => v.0.migrate_to(&Current::new(), &db, &mut progress).await?,
Version::V0_3_5_1(v) => v.0.migrate_to(&Current::new(), &db, &mut progress).await?,
Version::V0_3_5_2(v) => v.0.migrate_to(&Current::new(), &db, &mut progress).await?,
Version::V0_3_6(v) => v.0.migrate_to(&Current::new(), &db, &mut progress).await?,
Version::Other(_) => {
return Err(Error::new(
eyre!("Cannot downgrade"),
@@ -224,6 +239,7 @@ pub async fn init(db: &TypedPatchDb<Database>) -> Result<(), Error> {
))
}
}
progress.complete();
Ok(())
}

View File

@@ -20,7 +20,11 @@ pub fn data_dir<P: AsRef<Path>>(datadir: P, pkg_id: &PackageId, volume_id: &Volu
.join(volume_id)
}
pub fn asset_dir<P: AsRef<Path>>(datadir: P, pkg_id: &PackageId, version: &VersionString) -> PathBuf {
pub fn asset_dir<P: AsRef<Path>>(
datadir: P,
pkg_id: &PackageId,
version: &VersionString,
) -> PathBuf {
datadir
.as_ref()
.join(PKG_VOLUME_DIR)