Merge branch 'next/patch' of github.com:Start9Labs/start-os into next/minor

Aiden McClelland
2023-11-13 14:28:26 -07:00
990 changed files with 2016 additions and 6686 deletions

core/startos/src/account.rs Normal file

@@ -0,0 +1,132 @@
use std::time::SystemTime;
use ed25519_dalek::SecretKey;
use openssl::pkey::{PKey, Private};
use openssl::x509::X509;
use sqlx::PgExecutor;
use crate::hostname::{generate_hostname, generate_id, Hostname};
use crate::net::keys::Key;
use crate::net::ssl::{generate_key, make_root_cert};
use crate::prelude::*;
use crate::util::crypto::ed25519_expand_key;
fn hash_password(password: &str) -> Result<String, Error> {
argon2::hash_encoded(
password.as_bytes(),
&rand::random::<[u8; 16]>()[..],
&argon2::Config::rfc9106_low_mem(),
)
.with_kind(crate::ErrorKind::PasswordHashGeneration)
}
#[derive(Debug, Clone)]
pub struct AccountInfo {
pub server_id: String,
pub hostname: Hostname,
pub password: String,
pub key: Key,
pub root_ca_key: PKey<Private>,
pub root_ca_cert: X509,
}
impl AccountInfo {
pub fn new(password: &str, start_time: SystemTime) -> Result<Self, Error> {
let server_id = generate_id();
let hostname = generate_hostname();
let root_ca_key = generate_key()?;
let root_ca_cert = make_root_cert(&root_ca_key, &hostname, start_time)?;
Ok(Self {
server_id,
hostname,
password: hash_password(password)?,
key: Key::new(None),
root_ca_key,
root_ca_cert,
})
}
pub async fn load(secrets: impl PgExecutor<'_>) -> Result<Self, Error> {
let r = sqlx::query!("SELECT * FROM account WHERE id = 0")
.fetch_one(secrets)
.await?;
let server_id = r.server_id.unwrap_or_else(generate_id);
let hostname = r.hostname.map(Hostname).unwrap_or_else(generate_hostname);
let password = r.password;
let network_key = SecretKey::try_from(r.network_key).map_err(|e| {
Error::new(
eyre!("expected vec of len 32, got len {}", e.len()),
ErrorKind::ParseDbField,
)
})?;
let tor_key = if let Some(k) = &r.tor_key {
<[u8; 64]>::try_from(&k[..]).map_err(|_| {
Error::new(
eyre!("expected vec of len 64, got len {}", k.len()),
ErrorKind::ParseDbField,
)
})?
} else {
ed25519_expand_key(&network_key)
};
let key = Key::from_pair(None, network_key, tor_key);
let root_ca_key = PKey::private_key_from_pem(r.root_ca_key_pem.as_bytes())?;
let root_ca_cert = X509::from_pem(r.root_ca_cert_pem.as_bytes())?;
Ok(Self {
server_id,
hostname,
password,
key,
root_ca_key,
root_ca_cert,
})
}
pub async fn save(&self, secrets: impl PgExecutor<'_>) -> Result<(), Error> {
let server_id = self.server_id.as_str();
let hostname = self.hostname.0.as_str();
let password = self.password.as_str();
let network_key = self.key.as_bytes();
let network_key = network_key.as_slice();
let root_ca_key = String::from_utf8(self.root_ca_key.private_key_to_pem_pkcs8()?)?;
let root_ca_cert = String::from_utf8(self.root_ca_cert.to_pem()?)?;
sqlx::query!(
r#"
INSERT INTO account (
id,
server_id,
hostname,
password,
network_key,
root_ca_key_pem,
root_ca_cert_pem
) VALUES (
0, $1, $2, $3, $4, $5, $6
) ON CONFLICT (id) DO UPDATE SET
server_id = EXCLUDED.server_id,
hostname = EXCLUDED.hostname,
password = EXCLUDED.password,
network_key = EXCLUDED.network_key,
root_ca_key_pem = EXCLUDED.root_ca_key_pem,
root_ca_cert_pem = EXCLUDED.root_ca_cert_pem
"#,
server_id,
hostname,
password,
network_key,
root_ca_key,
root_ca_cert,
)
.execute(secrets)
.await?;
Ok(())
}
pub fn set_password(&mut self, password: &str) -> Result<(), Error> {
self.password = hash_password(password)?;
Ok(())
}
}
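
hash_password above produces an argon2 "encoded" string, which bundles the salt and cost parameters alongside the digest, so later verification (check_password in auth.rs below) needs only the stored string and the candidate password. A minimal sketch of that round-trip, assuming the same argon2 and rand crates this file already uses (the demo function name is illustrative):

fn demo_password_roundtrip() -> Result<(), argon2::Error> {
    let salt = rand::random::<[u8; 16]>();
    // Same parameter preset hash_password uses above.
    let hash = argon2::hash_encoded(b"correct horse", &salt, &argon2::Config::rfc9106_low_mem())?;
    // The encoded string carries the salt and parameters, so verification
    // takes only the hash string and the candidate password.
    assert!(argon2::verify_encoded(&hash, b"correct horse")?);
    assert!(!argon2::verify_encoded(&hash, b"wrong password")?);
    Ok(())
}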

core/startos/src/action.rs Normal file

@@ -0,0 +1,163 @@
use std::collections::{BTreeMap, BTreeSet};
use clap::ArgMatches;
use color_eyre::eyre::eyre;
use indexmap::IndexSet;
pub use models::ActionId;
use models::ImageId;
use rpc_toolkit::command;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use crate::config::{Config, ConfigSpec};
use crate::context::RpcContext;
use crate::prelude::*;
use crate::procedure::docker::DockerContainers;
use crate::procedure::{PackageProcedure, ProcedureName};
use crate::s9pk::manifest::PackageId;
use crate::util::serde::{display_serializable, parse_stdin_deserializable, IoFormat};
use crate::util::Version;
use crate::volume::Volumes;
use crate::{Error, ResultExt};
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Actions(pub BTreeMap<ActionId, Action>);
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "version")]
pub enum ActionResult {
#[serde(rename = "0")]
V0(ActionResultV0),
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ActionResultV0 {
pub message: String,
pub value: Option<String>,
pub copyable: bool,
pub qr: bool,
}
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum DockerStatus {
Running,
Stopped,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct Action {
pub name: String,
pub description: String,
#[serde(default)]
pub warning: Option<String>,
pub implementation: PackageProcedure,
pub allowed_statuses: IndexSet<DockerStatus>,
#[serde(default)]
pub input_spec: ConfigSpec,
}
impl Action {
#[instrument(skip_all)]
pub fn validate(
&self,
_container: &Option<DockerContainers>,
eos_version: &Version,
volumes: &Volumes,
image_ids: &BTreeSet<ImageId>,
) -> Result<(), Error> {
self.implementation
.validate(eos_version, volumes, image_ids, true)
.with_ctx(|_| {
(
crate::ErrorKind::ValidateS9pk,
format!("Action {}", self.name),
)
})
}
#[instrument(skip_all)]
pub async fn execute(
&self,
ctx: &RpcContext,
pkg_id: &PackageId,
pkg_version: &Version,
action_id: &ActionId,
volumes: &Volumes,
input: Option<Config>,
) -> Result<ActionResult, Error> {
if let Some(ref input) = input {
self.input_spec
.matches(&input)
.with_kind(crate::ErrorKind::ConfigSpecViolation)?;
}
self.implementation
.execute(
ctx,
pkg_id,
pkg_version,
ProcedureName::Action(action_id.clone()),
volumes,
input,
None,
)
.await?
.map_err(|e| Error::new(eyre!("{}", e.1), crate::ErrorKind::Action))
}
}
fn display_action_result(action_result: ActionResult, matches: &ArgMatches) {
if matches.is_present("format") {
return display_serializable(action_result, matches);
}
match action_result {
ActionResult::V0(ar) => {
println!(
"{}: {}",
ar.message,
serde_json::to_string(&ar.value).unwrap()
);
}
}
}
#[command(about = "Executes an action", display(display_action_result))]
#[instrument(skip_all)]
pub async fn action(
#[context] ctx: RpcContext,
#[arg(rename = "id")] pkg_id: PackageId,
#[arg(rename = "action-id")] action_id: ActionId,
#[arg(stdin, parse(parse_stdin_deserializable))] input: Option<Config>,
#[allow(unused_variables)]
#[arg(long = "format")]
format: Option<IoFormat>,
) -> Result<ActionResult, Error> {
let manifest = ctx
.db
.peek()
.await
.as_package_data()
.as_idx(&pkg_id)
.or_not_found(&pkg_id)?
.as_installed()
.or_not_found(&pkg_id)?
.as_manifest()
.de()?;
if let Some(action) = manifest.actions.0.get(&action_id) {
action
.execute(
&ctx,
&manifest.id,
&manifest.version,
&action_id,
&manifest.volumes,
input,
)
.await
} else {
Err(Error::new(
eyre!("Action not found in manifest"),
crate::ErrorKind::NotFound,
))
}
}
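
ActionResult above is internally tagged on "version", so the variant's fields are serialized inline next to the tag rather than nested under it. A sketch of the resulting wire format, assuming serde_json (already used in this file) and an illustrative helper name:

fn example_action_result_json() -> serde_json::Result<String> {
    let res = ActionResult::V0(ActionResultV0 {
        message: "Seed phrase".into(),
        value: Some("abandon abandon ...".into()),
        copyable: true,
        qr: false,
    });
    // -> {"version":"0","message":"Seed phrase","value":"abandon abandon ...","copyable":true,"qr":false}
    serde_json::to_string(&res)
}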

File diff suppressed because it is too large

File diff suppressed because it is too large

core/startos/src/auth.rs Normal file

@@ -0,0 +1,391 @@
use std::collections::BTreeMap;
use std::marker::PhantomData;
use chrono::{DateTime, Utc};
use clap::ArgMatches;
use color_eyre::eyre::eyre;
use josekit::jwk::Jwk;
use rpc_toolkit::command;
use rpc_toolkit::command_helpers::prelude::{RequestParts, ResponseParts};
use rpc_toolkit::yajrc::RpcError;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sqlx::{Executor, Postgres};
use tracing::instrument;
use crate::context::{CliContext, RpcContext};
use crate::middleware::auth::{AsLogoutSessionId, HasLoggedOutSessions, HashSessionToken};
use crate::middleware::encrypt::EncryptedWire;
use crate::prelude::*;
use crate::util::display_none;
use crate::util::serde::{display_serializable, IoFormat};
use crate::{ensure_code, Error, ResultExt};
#[derive(Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PasswordType {
EncryptedWire(EncryptedWire),
String(String),
}
impl PasswordType {
pub fn decrypt(self, current_secret: impl AsRef<Jwk>) -> Result<String, Error> {
match self {
PasswordType::String(x) => Ok(x),
PasswordType::EncryptedWire(x) => x.decrypt(current_secret).ok_or_else(|| {
Error::new(
color_eyre::eyre::eyre!("Couldn't decode password"),
crate::ErrorKind::Unknown,
)
}),
}
}
}
impl Default for PasswordType {
fn default() -> Self {
PasswordType::String(String::default())
}
}
impl std::fmt::Debug for PasswordType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "<REDACTED_PASSWORD>")?;
Ok(())
}
}
impl std::str::FromStr for PasswordType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match serde_json::from_str(s) {
Ok(a) => a,
Err(_) => PasswordType::String(s.to_string()),
})
}
}
#[command(subcommands(login, logout, session, reset_password, get_pubkey))]
pub fn auth() -> Result<(), Error> {
Ok(())
}
pub fn cli_metadata() -> Value {
serde_json::json!({
"platforms": ["cli"],
})
}
pub fn parse_metadata(_: &str, _: &ArgMatches) -> Result<Value, Error> {
Ok(cli_metadata())
}
#[test]
fn gen_pwd() {
println!(
"{:?}",
argon2::hash_encoded(
b"testing1234",
&rand::random::<[u8; 16]>()[..],
&argon2::Config::rfc9106_low_mem()
)
.unwrap()
)
}
#[instrument(skip_all)]
async fn cli_login(
ctx: CliContext,
password: Option<PasswordType>,
metadata: Value,
) -> Result<(), RpcError> {
let password = if let Some(password) = password {
password.decrypt(&ctx)?
} else {
rpassword::prompt_password("Password: ")?
};
rpc_toolkit::command_helpers::call_remote(
ctx,
"auth.login",
serde_json::json!({ "password": password, "metadata": metadata }),
PhantomData::<()>,
)
.await?
.result?;
Ok(())
}
pub fn check_password(hash: &str, password: &str) -> Result<(), Error> {
ensure_code!(
argon2::verify_encoded(&hash, password.as_bytes()).map_err(|_| {
Error::new(
eyre!("Password Incorrect"),
crate::ErrorKind::IncorrectPassword,
)
})?,
crate::ErrorKind::IncorrectPassword,
"Password Incorrect"
);
Ok(())
}
pub async fn check_password_against_db<Ex>(secrets: &mut Ex, password: &str) -> Result<(), Error>
where
for<'a> &'a mut Ex: Executor<'a, Database = Postgres>,
{
let pw_hash = sqlx::query!("SELECT password FROM account")
.fetch_one(secrets)
.await?
.password;
check_password(&pw_hash, password)?;
Ok(())
}
#[command(
custom_cli(cli_login(async, context(CliContext))),
display(display_none),
metadata(authenticated = false)
)]
#[instrument(skip_all)]
pub async fn login(
#[context] ctx: RpcContext,
#[request] req: &RequestParts,
#[response] res: &mut ResponseParts,
#[arg] password: Option<PasswordType>,
#[arg(
parse(parse_metadata),
default = "cli_metadata",
help = "RPC Only: This value cannot be overidden from the cli"
)]
metadata: Value,
) -> Result<(), Error> {
let password = password.unwrap_or_default().decrypt(&ctx)?;
let mut handle = ctx.secret_store.acquire().await?;
check_password_against_db(handle.as_mut(), &password).await?;
let hash_token = HashSessionToken::new();
let user_agent = req.headers.get("user-agent").and_then(|h| h.to_str().ok());
let metadata = serde_json::to_string(&metadata).with_kind(crate::ErrorKind::Database)?;
let hash_token_hashed = hash_token.hashed();
sqlx::query!(
"INSERT INTO session (id, user_agent, metadata) VALUES ($1, $2, $3)",
hash_token_hashed,
user_agent,
metadata,
)
.execute(handle.as_mut())
.await?;
res.headers.insert(
"set-cookie",
hash_token.header_value()?, // Should be impossible, but don't want to panic
);
Ok(())
}
#[command(display(display_none), metadata(authenticated = false))]
#[instrument(skip_all)]
pub async fn logout(
#[context] ctx: RpcContext,
#[request] req: &RequestParts,
) -> Result<Option<HasLoggedOutSessions>, Error> {
let auth = match HashSessionToken::from_request_parts(req) {
Err(_) => return Ok(None),
Ok(a) => a,
};
Ok(Some(HasLoggedOutSessions::new(vec![auth], &ctx).await?))
}
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct Session {
logged_in: DateTime<Utc>,
last_active: DateTime<Utc>,
user_agent: Option<String>,
metadata: Value,
}
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct SessionList {
current: String,
sessions: BTreeMap<String, Session>,
}
#[command(subcommands(list, kill))]
pub async fn session() -> Result<(), Error> {
Ok(())
}
fn display_sessions(arg: SessionList, matches: &ArgMatches) {
use prettytable::*;
if matches.is_present("format") {
return display_serializable(arg, matches);
}
let mut table = Table::new();
table.add_row(row![bc =>
"ID",
"LOGGED IN",
"LAST ACTIVE",
"USER AGENT",
"METADATA",
]);
for (id, session) in arg.sessions {
let mut row = row![
&id,
&format!("{}", session.logged_in),
&format!("{}", session.last_active),
session.user_agent.as_deref().unwrap_or("N/A"),
&format!("{}", session.metadata),
];
if id == arg.current {
row.iter_mut()
.map(|c| c.style(Attr::ForegroundColor(color::GREEN)))
.collect::<()>()
}
table.add_row(row);
}
table.print_tty(false).unwrap();
}
#[command(display(display_sessions))]
#[instrument(skip_all)]
pub async fn list(
#[context] ctx: RpcContext,
#[request] req: &RequestParts,
#[allow(unused_variables)]
#[arg(long = "format")]
format: Option<IoFormat>,
) -> Result<SessionList, Error> {
Ok(SessionList {
current: HashSessionToken::from_request_parts(req)?.as_hash(),
sessions: sqlx::query!(
"SELECT * FROM session WHERE logged_out IS NULL OR logged_out > CURRENT_TIMESTAMP"
)
.fetch_all(ctx.secret_store.acquire().await?.as_mut())
.await?
.into_iter()
.map(|row| {
Ok((
row.id,
Session {
logged_in: DateTime::from_utc(row.logged_in, Utc),
last_active: DateTime::from_utc(row.last_active, Utc),
user_agent: row.user_agent,
metadata: serde_json::from_str(&row.metadata)
.with_kind(crate::ErrorKind::Database)?,
},
))
})
.collect::<Result<_, Error>>()?,
})
}
fn parse_comma_separated(arg: &str, _: &ArgMatches) -> Result<Vec<String>, RpcError> {
Ok(arg.split(",").map(|s| s.trim().to_owned()).collect())
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct KillSessionId(String);
impl AsLogoutSessionId for KillSessionId {
fn as_logout_session_id(self) -> String {
self.0
}
}
#[command(display(display_none))]
#[instrument(skip_all)]
pub async fn kill(
#[context] ctx: RpcContext,
#[arg(parse(parse_comma_separated))] ids: Vec<String>,
) -> Result<(), Error> {
HasLoggedOutSessions::new(ids.into_iter().map(KillSessionId), &ctx).await?;
Ok(())
}
#[instrument(skip_all)]
async fn cli_reset_password(
ctx: CliContext,
old_password: Option<PasswordType>,
new_password: Option<PasswordType>,
) -> Result<(), RpcError> {
let old_password = if let Some(old_password) = old_password {
old_password.decrypt(&ctx)?
} else {
rpassword::prompt_password("Current Password: ")?
};
let new_password = if let Some(new_password) = new_password {
new_password.decrypt(&ctx)?
} else {
let new_password = rpassword::prompt_password("New Password: ")?;
if new_password != rpassword::prompt_password("Confirm: ")? {
return Err(Error::new(
eyre!("Passwords do not match"),
crate::ErrorKind::IncorrectPassword,
)
.into());
}
new_password
};
rpc_toolkit::command_helpers::call_remote(
ctx,
"auth.reset-password",
serde_json::json!({ "old-password": old_password, "new-password": new_password }),
PhantomData::<()>,
)
.await?
.result?;
Ok(())
}
#[command(
rename = "reset-password",
custom_cli(cli_reset_password(async, context(CliContext))),
display(display_none)
)]
#[instrument(skip_all)]
pub async fn reset_password(
#[context] ctx: RpcContext,
#[arg(rename = "old-password")] old_password: Option<PasswordType>,
#[arg(rename = "new-password")] new_password: Option<PasswordType>,
) -> Result<(), Error> {
let old_password = old_password.unwrap_or_default().decrypt(&ctx)?;
let new_password = new_password.unwrap_or_default().decrypt(&ctx)?;
let mut account = ctx.account.write().await;
if !argon2::verify_encoded(&account.password, old_password.as_bytes())
.with_kind(crate::ErrorKind::IncorrectPassword)?
{
return Err(Error::new(
eyre!("Incorrect Password"),
crate::ErrorKind::IncorrectPassword,
));
}
account.set_password(&new_password)?;
account.save(&ctx.secret_store).await?;
let account_password = &account.password;
ctx.db
.mutate(|d| {
d.as_server_info_mut()
.as_password_hash_mut()
.ser(account_password)
})
.await
}
#[command(
rename = "get-pubkey",
display(display_none),
metadata(authenticated = false)
)]
#[instrument(skip_all)]
pub async fn get_pubkey(#[context] ctx: RpcContext) -> Result<Jwk, RpcError> {
let secret = ctx.as_ref().clone();
let pub_key = secret.to_public_key()?;
Ok(pub_key)
}
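
PasswordType's FromStr above lets the CLI accept either a raw password or a JSON-encoded value: anything that is not valid JSON falls through to PasswordType::String verbatim, while JSON input (for example an EncryptedWire object) goes through the untagged serde representation. A small sketch of the fallback branch (illustrative function name):

use std::str::FromStr;

fn demo_password_parse() {
    // "hunter2" is not valid JSON, so serde_json::from_str fails and the
    // fallback keeps the raw text as the password.
    let parsed = PasswordType::from_str("hunter2").unwrap();
    assert!(matches!(parsed, PasswordType::String(ref s) if s == "hunter2"));
}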


@@ -0,0 +1,322 @@
use std::collections::BTreeMap;
use std::panic::UnwindSafe;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use chrono::Utc;
use clap::ArgMatches;
use color_eyre::eyre::eyre;
use helpers::AtomicFile;
use imbl::OrdSet;
use models::Version;
use rpc_toolkit::command;
use tokio::io::AsyncWriteExt;
use tokio::sync::Mutex;
use tracing::instrument;
use super::target::BackupTargetId;
use super::PackageBackupReport;
use crate::auth::check_password_against_db;
use crate::backup::os::OsBackup;
use crate::backup::{BackupReport, ServerBackupReport};
use crate::context::RpcContext;
use crate::db::model::BackupProgress;
use crate::db::package::get_packages;
use crate::disk::mount::backup::BackupMountGuard;
use crate::disk::mount::filesystem::ReadWrite;
use crate::disk::mount::guard::TmpMountGuard;
use crate::manager::BackupReturn;
use crate::notifications::NotificationLevel;
use crate::prelude::*;
use crate::s9pk::manifest::PackageId;
use crate::util::display_none;
use crate::util::io::dir_copy;
use crate::util::serde::IoFormat;
use crate::version::VersionT;
fn parse_comma_separated(arg: &str, _: &ArgMatches) -> Result<OrdSet<PackageId>, Error> {
arg.split(',')
.map(|s| s.trim().parse::<PackageId>().map_err(Error::from))
.collect()
}
#[command(rename = "create", display(display_none))]
#[instrument(skip(ctx, old_password, password))]
pub async fn backup_all(
#[context] ctx: RpcContext,
#[arg(rename = "target-id")] target_id: BackupTargetId,
#[arg(rename = "old-password", long = "old-password")] old_password: Option<
crate::auth::PasswordType,
>,
#[arg(
rename = "package-ids",
long = "package-ids",
parse(parse_comma_separated)
)]
package_ids: Option<OrdSet<PackageId>>,
#[arg] password: crate::auth::PasswordType,
) -> Result<(), Error> {
let db = ctx.db.peek().await;
let old_password_decrypted = old_password
.as_ref()
.unwrap_or(&password)
.clone()
.decrypt(&ctx)?;
let password = password.decrypt(&ctx)?;
check_password_against_db(ctx.secret_store.acquire().await?.as_mut(), &password).await?;
let fs = target_id
.load(ctx.secret_store.acquire().await?.as_mut())
.await?;
let mut backup_guard = BackupMountGuard::mount(
TmpMountGuard::mount(&fs, ReadWrite).await?,
&old_password_decrypted,
)
.await?;
let package_ids = if let Some(ids) = package_ids {
ids.into_iter()
.flat_map(|package_id| {
let version = db
.as_package_data()
.as_idx(&package_id)?
.as_manifest()
.as_version()
.de()
.ok()?;
Some((package_id, version))
})
.collect()
} else {
get_packages(db.clone())?.into_iter().collect()
};
if old_password.is_some() {
backup_guard.change_password(&password)?;
}
assure_backing_up(&ctx.db, &package_ids).await?;
tokio::task::spawn(async move {
let backup_res = perform_backup(&ctx, backup_guard, &package_ids).await;
match backup_res {
Ok(report) if report.iter().all(|(_, rep)| rep.error.is_none()) => ctx
.notification_manager
.notify(
ctx.db.clone(),
None,
NotificationLevel::Success,
"Backup Complete".to_owned(),
"Your backup has completed".to_owned(),
BackupReport {
server: ServerBackupReport {
attempted: true,
error: None,
},
packages: report
.into_iter()
.map(|((package_id, _), value)| (package_id, value))
.collect(),
},
None,
)
.await
.expect("failed to send notification"),
Ok(report) => ctx
.notification_manager
.notify(
ctx.db.clone(),
None,
NotificationLevel::Warning,
"Backup Complete".to_owned(),
"Your backup has completed, but some package(s) failed to backup".to_owned(),
BackupReport {
server: ServerBackupReport {
attempted: true,
error: None,
},
packages: report
.into_iter()
.map(|((package_id, _), value)| (package_id, value))
.collect(),
},
None,
)
.await
.expect("failed to send notification"),
Err(e) => {
tracing::error!("Backup Failed: {}", e);
tracing::debug!("{:?}", e);
ctx.notification_manager
.notify(
ctx.db.clone(),
None,
NotificationLevel::Error,
"Backup Failed".to_owned(),
"Your backup failed to complete.".to_owned(),
BackupReport {
server: ServerBackupReport {
attempted: true,
error: Some(e.to_string()),
},
packages: BTreeMap::new(),
},
None,
)
.await
.expect("failed to send notification");
}
}
ctx.db
.mutate(|v| {
v.as_server_info_mut()
.as_status_info_mut()
.as_backup_progress_mut()
.ser(&None)
})
.await?;
Ok::<(), Error>(())
});
Ok(())
}
#[instrument(skip(db, packages))]
async fn assure_backing_up(
db: &PatchDb,
packages: impl IntoIterator<Item = &(PackageId, Version)> + UnwindSafe + Send,
) -> Result<(), Error> {
db.mutate(|v| {
let backing_up = v
.as_server_info_mut()
.as_status_info_mut()
.as_backup_progress_mut();
if backing_up
.clone()
.de()?
.iter()
.flat_map(|x| x.values())
.fold(false, |acc, x| {
if !x.complete {
return true;
}
acc
})
{
return Err(Error::new(
eyre!("Server is already backing up!"),
ErrorKind::InvalidRequest,
));
}
backing_up.ser(&Some(
packages
.into_iter()
.map(|(x, _)| (x.clone(), BackupProgress { complete: false }))
.collect(),
))?;
Ok(())
})
.await
}
#[instrument(skip(ctx, backup_guard))]
async fn perform_backup(
ctx: &RpcContext,
backup_guard: BackupMountGuard<TmpMountGuard>,
package_ids: &OrdSet<(PackageId, Version)>,
) -> Result<BTreeMap<(PackageId, Version), PackageBackupReport>, Error> {
let mut backup_report = BTreeMap::new();
let backup_guard = Arc::new(Mutex::new(backup_guard));
for package_id in package_ids {
let (response, _report) = match ctx
.managers
.get(package_id)
.await
.ok_or_else(|| Error::new(eyre!("Manager not found"), ErrorKind::InvalidRequest))?
.backup(backup_guard.clone())
.await
{
BackupReturn::Ran { report, res } => (res, report),
BackupReturn::AlreadyRunning(report) => {
backup_report.insert(package_id.clone(), report);
continue;
}
BackupReturn::Error(error) => {
tracing::warn!("Backup thread error");
tracing::debug!("{error:?}");
backup_report.insert(
package_id.clone(),
PackageBackupReport {
error: Some("Backup thread error".to_owned()),
},
);
continue;
}
};
backup_report.insert(
package_id.clone(),
PackageBackupReport {
error: response.as_ref().err().map(|e| e.to_string()),
},
);
if let Ok(pkg_meta) = response {
backup_guard
.lock()
.await
.metadata
.package_backups
.insert(package_id.0.clone(), pkg_meta);
}
}
let ui = ctx.db.peek().await.into_ui().de()?;
let mut os_backup_file = AtomicFile::new(
backup_guard.lock().await.as_ref().join("os-backup.cbor"),
None::<PathBuf>,
)
.await
.with_kind(ErrorKind::Filesystem)?;
os_backup_file
.write_all(&IoFormat::Cbor.to_vec(&OsBackup {
account: ctx.account.read().await.clone(),
ui,
})?)
.await?;
os_backup_file
.save()
.await
.with_kind(ErrorKind::Filesystem)?;
let luks_folder_old = backup_guard.lock().await.as_ref().join("luks.old");
if tokio::fs::metadata(&luks_folder_old).await.is_ok() {
tokio::fs::remove_dir_all(&luks_folder_old).await?;
}
let luks_folder_bak = backup_guard.lock().await.as_ref().join("luks");
if tokio::fs::metadata(&luks_folder_bak).await.is_ok() {
tokio::fs::rename(&luks_folder_bak, &luks_folder_old).await?;
}
let luks_folder = Path::new("/media/embassy/config/luks");
if tokio::fs::metadata(&luks_folder).await.is_ok() {
dir_copy(&luks_folder, &luks_folder_bak, None).await?;
}
let timestamp = Some(Utc::now());
let mut backup_guard = Arc::try_unwrap(backup_guard)
.map_err(|_err| {
Error::new(
eyre!("Backup guard could not ensure that the others where dropped"),
ErrorKind::Unknown,
)
})?
.into_inner();
backup_guard.unencrypted_metadata.version = crate::version::Current::new().semver().into();
backup_guard.unencrypted_metadata.full = true;
backup_guard.metadata.version = crate::version::Current::new().semver().into();
backup_guard.metadata.timestamp = timestamp;
backup_guard.save_and_unmount().await?;
ctx.db
.mutate(|v| v.as_server_info_mut().as_last_backup_mut().ser(&timestamp))
.await?;
Ok(backup_report)
}
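
The fold inside assure_backing_up above is an "is any package still incomplete" check against the backup-progress map. Reduced to plain iterator combinators, and assuming the same BackupProgress and PackageId types from this diff, it is equivalent to this sketch:

use std::collections::BTreeMap;

fn any_backup_in_progress(progress: &Option<BTreeMap<PackageId, BackupProgress>>) -> bool {
    // None, or a map whose entries are all complete, means no backup is running.
    progress
        .iter()
        .flat_map(|by_package| by_package.values())
        .any(|p| !p.complete)
}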


@@ -0,0 +1,226 @@
use std::collections::{BTreeMap, BTreeSet};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use chrono::{DateTime, Utc};
use color_eyre::eyre::eyre;
use helpers::AtomicFile;
use models::{ImageId, OptionExt};
use reqwest::Url;
use rpc_toolkit::command;
use serde::{Deserialize, Serialize};
use tokio::fs::File;
use tokio::io::AsyncWriteExt;
use tracing::instrument;
use self::target::PackageBackupInfo;
use crate::context::RpcContext;
use crate::install::PKG_ARCHIVE_DIR;
use crate::manager::manager_seed::ManagerSeed;
use crate::net::interface::InterfaceId;
use crate::net::keys::Key;
use crate::prelude::*;
use crate::procedure::docker::DockerContainers;
use crate::procedure::{NoOutput, PackageProcedure, ProcedureName};
use crate::s9pk::manifest::PackageId;
use crate::util::serde::{Base32, Base64, IoFormat};
use crate::util::Version;
use crate::version::{Current, VersionT};
use crate::volume::{backup_dir, Volume, VolumeId, Volumes, BACKUP_DIR};
use crate::{Error, ErrorKind, ResultExt};
pub mod backup_bulk;
pub mod os;
pub mod restore;
pub mod target;
#[derive(Debug, Deserialize, Serialize)]
pub struct BackupReport {
server: ServerBackupReport,
packages: BTreeMap<PackageId, PackageBackupReport>,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct ServerBackupReport {
attempted: bool,
error: Option<String>,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct PackageBackupReport {
pub error: Option<String>,
}
#[command(subcommands(backup_bulk::backup_all, target::target))]
pub fn backup() -> Result<(), Error> {
Ok(())
}
#[command(rename = "backup", subcommands(restore::restore_packages_rpc))]
pub fn package_backup() -> Result<(), Error> {
Ok(())
}
#[derive(Deserialize, Serialize)]
struct BackupMetadata {
pub timestamp: DateTime<Utc>,
#[serde(default)]
pub network_keys: BTreeMap<InterfaceId, Base64<[u8; 32]>>,
#[serde(default)]
pub tor_keys: BTreeMap<InterfaceId, Base32<[u8; 64]>>, // DEPRECATED
pub marketplace_url: Option<Url>,
}
#[derive(Clone, Debug, Deserialize, Serialize, HasModel)]
#[model = "Model<Self>"]
pub struct BackupActions {
pub create: PackageProcedure,
pub restore: PackageProcedure,
}
impl BackupActions {
pub fn validate(
&self,
_container: &Option<DockerContainers>,
eos_version: &Version,
volumes: &Volumes,
image_ids: &BTreeSet<ImageId>,
) -> Result<(), Error> {
self.create
.validate(eos_version, volumes, image_ids, false)
.with_ctx(|_| (crate::ErrorKind::ValidateS9pk, "Backup Create"))?;
self.restore
.validate(eos_version, volumes, image_ids, false)
.with_ctx(|_| (crate::ErrorKind::ValidateS9pk, "Backup Restore"))?;
Ok(())
}
#[instrument(skip_all)]
pub async fn create(&self, seed: Arc<ManagerSeed>) -> Result<PackageBackupInfo, Error> {
let manifest = &seed.manifest;
let mut volumes = seed.manifest.volumes.to_readonly();
let ctx = &seed.ctx;
let pkg_id = &manifest.id;
let pkg_version = &manifest.version;
volumes.insert(VolumeId::Backup, Volume::Backup { readonly: false });
let backup_dir = backup_dir(&manifest.id);
if tokio::fs::metadata(&backup_dir).await.is_err() {
tokio::fs::create_dir_all(&backup_dir).await?
}
self.create
.execute::<(), NoOutput>(
ctx,
pkg_id,
pkg_version,
ProcedureName::CreateBackup,
&volumes,
None,
None,
)
.await?
.map_err(|e| eyre!("{}", e.1))
.with_kind(crate::ErrorKind::Backup)?;
let (network_keys, tor_keys): (Vec<_>, Vec<_>) =
Key::for_package(&ctx.secret_store, pkg_id)
.await?
.into_iter()
.filter_map(|k| {
let interface = k.interface().map(|(_, i)| i)?;
Some((
(interface.clone(), Base64(k.as_bytes())),
(interface, Base32(k.tor_key().as_bytes())),
))
})
.unzip();
let marketplace_url = ctx
.db
.peek()
.await
.as_package_data()
.as_idx(&pkg_id)
.or_not_found(pkg_id)?
.expect_as_installed()?
.as_installed()
.as_marketplace_url()
.de()?;
let tmp_path = Path::new(BACKUP_DIR)
.join(pkg_id)
.join(format!("{}.s9pk", pkg_id));
let s9pk_path = ctx
.datadir
.join(PKG_ARCHIVE_DIR)
.join(pkg_id)
.join(pkg_version.as_str())
.join(format!("{}.s9pk", pkg_id));
let mut infile = File::open(&s9pk_path).await?;
let mut outfile = AtomicFile::new(&tmp_path, None::<PathBuf>)
.await
.with_kind(ErrorKind::Filesystem)?;
tokio::io::copy(&mut infile, &mut *outfile)
.await
.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
format!("cp {} -> {}", s9pk_path.display(), tmp_path.display()),
)
})?;
outfile.save().await.with_kind(ErrorKind::Filesystem)?;
let timestamp = Utc::now();
let metadata_path = Path::new(BACKUP_DIR).join(pkg_id).join("metadata.cbor");
let mut outfile = AtomicFile::new(&metadata_path, None::<PathBuf>)
.await
.with_kind(ErrorKind::Filesystem)?;
let network_keys = network_keys.into_iter().collect();
let tor_keys = tor_keys.into_iter().collect();
outfile
.write_all(&IoFormat::Cbor.to_vec(&BackupMetadata {
timestamp,
network_keys,
tor_keys,
marketplace_url,
})?)
.await?;
outfile.save().await.with_kind(ErrorKind::Filesystem)?;
Ok(PackageBackupInfo {
os_version: Current::new().semver().into(),
title: manifest.title.clone(),
version: pkg_version.clone(),
timestamp,
})
}
#[instrument(skip_all)]
pub async fn restore(
&self,
ctx: &RpcContext,
pkg_id: &PackageId,
pkg_version: &Version,
volumes: &Volumes,
) -> Result<Option<Url>, Error> {
let mut volumes = volumes.clone();
volumes.insert(VolumeId::Backup, Volume::Backup { readonly: true });
self.restore
.execute::<(), NoOutput>(
ctx,
pkg_id,
pkg_version,
ProcedureName::RestoreBackup,
&volumes,
None,
None,
)
.await?
.map_err(|e| eyre!("{}", e.1))
.with_kind(crate::ErrorKind::Restore)?;
let metadata_path = Path::new(BACKUP_DIR).join(pkg_id).join("metadata.cbor");
let metadata: BackupMetadata = IoFormat::Cbor.from_slice(
&tokio::fs::read(&metadata_path).await.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
metadata_path.display().to_string(),
)
})?,
)?;
Ok(metadata.marketplace_url)
}
}
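
Because network_keys and tor_keys in BackupMetadata carry #[serde(default)], a metadata file written before either map existed still deserializes, with the missing maps coming back empty. A sketch of that behavior, shown with serde_json for readability even though the real file is CBOR via IoFormat::Cbor (illustrative function name):

fn example_old_metadata() -> Result<(), serde_json::Error> {
    let old = r#"{"timestamp":"2023-11-13T21:28:26Z","marketplace_url":null}"#;
    let meta: BackupMetadata = serde_json::from_str(old)?;
    assert!(meta.network_keys.is_empty());
    assert!(meta.tor_keys.is_empty());
    Ok(())
}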


@@ -0,0 +1,122 @@
use openssl::pkey::PKey;
use openssl::x509::X509;
use patch_db::Value;
use serde::{Deserialize, Serialize};
use crate::account::AccountInfo;
use crate::hostname::{generate_hostname, generate_id, Hostname};
use crate::net::keys::Key;
use crate::prelude::*;
use crate::util::serde::Base64;
pub struct OsBackup {
pub account: AccountInfo,
pub ui: Value,
}
impl<'de> Deserialize<'de> for OsBackup {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let tagged = OsBackupSerDe::deserialize(deserializer)?;
match tagged.version {
0 => patch_db::value::from_value::<OsBackupV0>(tagged.rest)
.map_err(serde::de::Error::custom)?
.project()
.map_err(serde::de::Error::custom),
1 => patch_db::value::from_value::<OsBackupV1>(tagged.rest)
.map_err(serde::de::Error::custom)?
.project()
.map_err(serde::de::Error::custom),
v => Err(serde::de::Error::custom(&format!(
"Unknown backup version {v}"
))),
}
}
}
impl Serialize for OsBackup {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
OsBackupSerDe {
version: 1,
rest: patch_db::value::to_value(
&OsBackupV1::unproject(self).map_err(serde::ser::Error::custom)?,
)
.map_err(serde::ser::Error::custom)?,
}
.serialize(serializer)
}
}
#[derive(Deserialize, Serialize)]
struct OsBackupSerDe {
#[serde(default)]
version: usize,
#[serde(flatten)]
rest: Value,
}
/// V0
#[derive(Deserialize)]
#[serde(rename = "kebab-case")]
struct OsBackupV0 {
// tor_key: Base32<[u8; 64]>,
root_ca_key: String, // PEM Encoded OpenSSL Key
root_ca_cert: String, // PEM Encoded OpenSSL X509 Certificate
ui: Value, // JSON Value
}
impl OsBackupV0 {
fn project(self) -> Result<OsBackup, Error> {
Ok(OsBackup {
account: AccountInfo {
server_id: generate_id(),
hostname: generate_hostname(),
password: Default::default(),
key: Key::new(None),
root_ca_key: PKey::private_key_from_pem(self.root_ca_key.as_bytes())?,
root_ca_cert: X509::from_pem(self.root_ca_cert.as_bytes())?,
},
ui: self.ui,
})
}
}
/// V1
#[derive(Deserialize, Serialize)]
#[serde(rename = "kebab-case")]
struct OsBackupV1 {
server_id: String, // uuidv4
hostname: String, // embassy-<adjective>-<noun>
net_key: Base64<[u8; 32]>, // Ed25519 Secret Key
root_ca_key: String, // PEM Encoded OpenSSL Key
root_ca_cert: String, // PEM Encoded OpenSSL X509 Certificate
ui: Value, // JSON Value
// TODO add more
}
impl OsBackupV1 {
fn project(self) -> Result<OsBackup, Error> {
Ok(OsBackup {
account: AccountInfo {
server_id: self.server_id,
hostname: Hostname(self.hostname),
password: Default::default(),
key: Key::from_bytes(None, self.net_key.0),
root_ca_key: PKey::private_key_from_pem(self.root_ca_key.as_bytes())?,
root_ca_cert: X509::from_pem(self.root_ca_cert.as_bytes())?,
},
ui: self.ui,
})
}
fn unproject(backup: &OsBackup) -> Result<Self, Error> {
Ok(Self {
server_id: backup.account.server_id.clone(),
hostname: backup.account.hostname.0.clone(),
net_key: Base64(backup.account.key.as_bytes()),
root_ca_key: String::from_utf8(backup.account.root_ca_key.private_key_to_pem_pkcs8()?)?,
root_ca_cert: String::from_utf8(backup.account.root_ca_cert.to_pem()?)?,
ui: backup.ui.clone(),
})
}
}
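
OsBackupSerDe above is the version envelope: a defaulted "version" tag plus a #[serde(flatten)] catch-all for the remaining payload, so a backup written before the tag existed reads back as version 0 and routes through OsBackupV0, while current backups carry version 1. A standalone sketch of the same pattern, assuming only serde and serde_json (names are illustrative):

use serde::{Deserialize, Serialize};

#[derive(Deserialize, Serialize)]
struct Envelope {
    #[serde(default)]
    version: usize,
    #[serde(flatten)]
    rest: serde_json::Value,
}

fn demo_envelope() -> serde_json::Result<()> {
    // Legacy payload with no tag: version defaults to 0.
    let v0: Envelope = serde_json::from_str(r#"{"root_ca_key":"...","ui":{}}"#)?;
    assert_eq!(v0.version, 0);
    // Current payload: the tag is consumed and the other fields land in `rest`.
    let v1: Envelope = serde_json::from_str(r#"{"version":1,"server_id":"abcd","ui":{}}"#)?;
    assert_eq!(v1.version, 1);
    assert!(v1.rest.get("server_id").is_some());
    Ok(())
}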


@@ -0,0 +1,461 @@
use std::collections::BTreeMap;
use std::path::Path;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
use clap::ArgMatches;
use futures::future::BoxFuture;
use futures::{stream, FutureExt, StreamExt};
use openssl::x509::X509;
use rpc_toolkit::command;
use sqlx::Connection;
use tokio::fs::File;
use torut::onion::OnionAddressV3;
use tracing::instrument;
use super::target::BackupTargetId;
use crate::backup::os::OsBackup;
use crate::backup::BackupMetadata;
use crate::context::rpc::RpcContextConfig;
use crate::context::{RpcContext, SetupContext};
use crate::db::model::{PackageDataEntry, PackageDataEntryRestoring, StaticFiles};
use crate::disk::mount::backup::{BackupMountGuard, PackageBackupMountGuard};
use crate::disk::mount::filesystem::ReadWrite;
use crate::disk::mount::guard::TmpMountGuard;
use crate::hostname::Hostname;
use crate::init::init;
use crate::install::progress::InstallProgress;
use crate::install::{download_install_s9pk, PKG_PUBLIC_DIR};
use crate::notifications::NotificationLevel;
use crate::prelude::*;
use crate::s9pk::manifest::{Manifest, PackageId};
use crate::s9pk::reader::S9pkReader;
use crate::setup::SetupStatus;
use crate::util::display_none;
use crate::util::io::dir_size;
use crate::util::serde::IoFormat;
use crate::volume::{backup_dir, BACKUP_DIR, PKG_VOLUME_DIR};
fn parse_comma_separated(arg: &str, _: &ArgMatches) -> Result<Vec<PackageId>, Error> {
arg.split(',')
.map(|s| s.trim().parse().map_err(Error::from))
.collect()
}
#[command(rename = "restore", display(display_none))]
#[instrument(skip(ctx, password))]
pub async fn restore_packages_rpc(
#[context] ctx: RpcContext,
#[arg(parse(parse_comma_separated))] ids: Vec<PackageId>,
#[arg(rename = "target-id")] target_id: BackupTargetId,
#[arg] password: String,
) -> Result<(), Error> {
let fs = target_id
.load(ctx.secret_store.acquire().await?.as_mut())
.await?;
let backup_guard =
BackupMountGuard::mount(TmpMountGuard::mount(&fs, ReadWrite).await?, &password).await?;
let (backup_guard, tasks, _) = restore_packages(&ctx, backup_guard, ids).await?;
tokio::spawn(async move {
stream::iter(tasks.into_iter().map(|x| (x, ctx.clone())))
.for_each_concurrent(5, |(res, ctx)| async move {
match res.await {
(Ok(_), _) => (),
(Err(err), package_id) => {
if let Err(err) = ctx
.notification_manager
.notify(
ctx.db.clone(),
Some(package_id.clone()),
NotificationLevel::Error,
"Restoration Failure".to_string(),
format!("Error restoring package {}: {}", package_id, err),
(),
None,
)
.await
{
tracing::error!("Failed to notify: {}", err);
tracing::debug!("{:?}", err);
};
tracing::error!("Error restoring package {}: {}", package_id, err);
tracing::debug!("{:?}", err);
}
}
})
.await;
if let Err(e) = backup_guard.unmount().await {
tracing::error!("Error unmounting backup drive: {}", e);
tracing::debug!("{:?}", e);
}
});
Ok(())
}
async fn approximate_progress(
rpc_ctx: &RpcContext,
progress: &mut ProgressInfo,
) -> Result<(), Error> {
for (id, size) in &mut progress.target_volume_size {
let dir = rpc_ctx.datadir.join(PKG_VOLUME_DIR).join(id).join("data");
if tokio::fs::metadata(&dir).await.is_err() {
*size = 0;
} else {
*size = dir_size(&dir, None).await?;
}
}
Ok(())
}
async fn approximate_progress_loop(
ctx: &SetupContext,
rpc_ctx: &RpcContext,
mut starting_info: ProgressInfo,
) {
loop {
if let Err(e) = approximate_progress(rpc_ctx, &mut starting_info).await {
tracing::error!("Failed to approximate restore progress: {}", e);
tracing::debug!("{:?}", e);
} else {
*ctx.setup_status.write().await = Some(Ok(starting_info.flatten()));
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
#[derive(Debug, Default)]
struct ProgressInfo {
package_installs: BTreeMap<PackageId, Arc<InstallProgress>>,
src_volume_size: BTreeMap<PackageId, u64>,
target_volume_size: BTreeMap<PackageId, u64>,
}
impl ProgressInfo {
fn flatten(&self) -> SetupStatus {
let mut total_bytes = 0;
let mut bytes_transferred = 0;
for progress in self.package_installs.values() {
total_bytes += ((progress.size.unwrap_or(0) as f64) * 2.2) as u64;
bytes_transferred += progress.downloaded.load(Ordering::SeqCst);
bytes_transferred += ((progress.validated.load(Ordering::SeqCst) as f64) * 0.2) as u64;
bytes_transferred += progress.unpacked.load(Ordering::SeqCst);
}
for size in self.src_volume_size.values() {
total_bytes += *size;
}
for size in self.target_volume_size.values() {
bytes_transferred += *size;
}
if bytes_transferred > total_bytes {
bytes_transferred = total_bytes;
}
SetupStatus {
total_bytes: Some(total_bytes),
bytes_transferred,
complete: false,
}
}
}
#[instrument(skip(ctx))]
pub async fn recover_full_embassy(
ctx: SetupContext,
disk_guid: Arc<String>,
embassy_password: String,
recovery_source: TmpMountGuard,
recovery_password: Option<String>,
) -> Result<(Arc<String>, Hostname, OnionAddressV3, X509), Error> {
let backup_guard = BackupMountGuard::mount(
recovery_source,
recovery_password.as_deref().unwrap_or_default(),
)
.await?;
let os_backup_path = backup_guard.as_ref().join("os-backup.cbor");
let mut os_backup: OsBackup = IoFormat::Cbor.from_slice(
&tokio::fs::read(&os_backup_path)
.await
.with_ctx(|_| (ErrorKind::Filesystem, os_backup_path.display().to_string()))?,
)?;
os_backup.account.password = argon2::hash_encoded(
embassy_password.as_bytes(),
&rand::random::<[u8; 16]>()[..],
&argon2::Config::rfc9106_low_mem(),
)
.with_kind(ErrorKind::PasswordHashGeneration)?;
let secret_store = ctx.secret_store().await?;
os_backup.account.save(&secret_store).await?;
secret_store.close().await;
let cfg = RpcContextConfig::load(ctx.config_path.clone()).await?;
init(&cfg).await?;
let rpc_ctx = RpcContext::init(ctx.config_path.clone(), disk_guid.clone()).await?;
let ids: Vec<_> = backup_guard
.metadata
.package_backups
.keys()
.cloned()
.collect();
let (backup_guard, tasks, progress_info) =
restore_packages(&rpc_ctx, backup_guard, ids).await?;
let task_consumer_rpc_ctx = rpc_ctx.clone();
tokio::select! {
_ = async move {
stream::iter(tasks.into_iter().map(|x| (x, task_consumer_rpc_ctx.clone())))
.for_each_concurrent(5, |(res, ctx)| async move {
match res.await {
(Ok(_), _) => (),
(Err(err), package_id) => {
if let Err(err) = ctx.notification_manager.notify(
ctx.db.clone(),
Some(package_id.clone()),
NotificationLevel::Error,
"Restoration Failure".to_string(), format!("Error restoring package {}: {}", package_id,err), (), None).await{
tracing::error!("Failed to notify: {}", err);
tracing::debug!("{:?}", err);
};
tracing::error!("Error restoring package {}: {}", package_id, err);
tracing::debug!("{:?}", err);
},
}
}).await;
} => {
},
_ = approximate_progress_loop(&ctx, &rpc_ctx, progress_info) => unreachable!(concat!(module_path!(), "::approximate_progress_loop should not terminate")),
}
backup_guard.unmount().await?;
rpc_ctx.shutdown().await?;
Ok((
disk_guid,
os_backup.account.hostname,
os_backup.account.key.tor_address(),
os_backup.account.root_ca_cert,
))
}
#[instrument(skip(ctx, backup_guard))]
async fn restore_packages(
ctx: &RpcContext,
backup_guard: BackupMountGuard<TmpMountGuard>,
ids: Vec<PackageId>,
) -> Result<
(
BackupMountGuard<TmpMountGuard>,
Vec<BoxFuture<'static, (Result<(), Error>, PackageId)>>,
ProgressInfo,
),
Error,
> {
let guards = assure_restoring(ctx, ids, &backup_guard).await?;
let mut progress_info = ProgressInfo::default();
let mut tasks = Vec::with_capacity(guards.len());
for (manifest, guard) in guards {
let id = manifest.id.clone();
let (progress, task) = restore_package(ctx.clone(), manifest, guard).await?;
progress_info
.package_installs
.insert(id.clone(), progress.clone());
progress_info
.src_volume_size
.insert(id.clone(), dir_size(backup_dir(&id), None).await?);
progress_info.target_volume_size.insert(id.clone(), 0);
let package_id = id.clone();
tasks.push(
async move {
if let Err(e) = task.await {
tracing::error!("Error restoring package {}: {}", id, e);
tracing::debug!("{:?}", e);
Err(e)
} else {
Ok(())
}
}
.map(|x| (x, package_id))
.boxed(),
);
}
Ok((backup_guard, tasks, progress_info))
}
#[instrument(skip(ctx, backup_guard))]
async fn assure_restoring(
ctx: &RpcContext,
ids: Vec<PackageId>,
backup_guard: &BackupMountGuard<TmpMountGuard>,
) -> Result<Vec<(Manifest, PackageBackupMountGuard)>, Error> {
let mut guards = Vec::with_capacity(ids.len());
let mut insert_packages = BTreeMap::new();
for id in ids {
let peek = ctx.db.peek().await;
let model = peek.as_package_data().as_idx(&id);
if model.is_some() {
return Err(Error::new(
eyre!("Can't restore over existing package: {}", id),
crate::ErrorKind::InvalidRequest,
));
}
let guard = backup_guard.mount_package_backup(&id).await?;
let s9pk_path = Path::new(BACKUP_DIR).join(&id).join(format!("{}.s9pk", id));
let mut rdr = S9pkReader::open(&s9pk_path, false).await?;
let manifest = rdr.manifest().await?;
let version = manifest.version.clone();
let progress = Arc::new(InstallProgress::new(Some(
tokio::fs::metadata(&s9pk_path).await?.len(),
)));
let public_dir_path = ctx
.datadir
.join(PKG_PUBLIC_DIR)
.join(&id)
.join(version.as_str());
tokio::fs::create_dir_all(&public_dir_path).await?;
let license_path = public_dir_path.join("LICENSE.md");
let mut dst = File::create(&license_path).await?;
tokio::io::copy(&mut rdr.license().await?, &mut dst).await?;
dst.sync_all().await?;
let instructions_path = public_dir_path.join("INSTRUCTIONS.md");
let mut dst = File::create(&instructions_path).await?;
tokio::io::copy(&mut rdr.instructions().await?, &mut dst).await?;
dst.sync_all().await?;
let icon_path = Path::new("icon").with_extension(&manifest.assets.icon_type());
let icon_path = public_dir_path.join(&icon_path);
let mut dst = File::create(&icon_path).await?;
tokio::io::copy(&mut rdr.icon().await?, &mut dst).await?;
dst.sync_all().await?;
insert_packages.insert(
id.clone(),
PackageDataEntry::Restoring(PackageDataEntryRestoring {
install_progress: progress.clone(),
static_files: StaticFiles::local(&id, &version, manifest.assets.icon_type()),
manifest: manifest.clone(),
}),
);
guards.push((manifest, guard));
}
ctx.db
.mutate(|db| {
for (id, package) in insert_packages {
db.as_package_data_mut().insert(&id, &package)?;
}
Ok(())
})
.await?;
Ok(guards)
}
#[instrument(skip(ctx, guard))]
async fn restore_package<'a>(
ctx: RpcContext,
manifest: Manifest,
guard: PackageBackupMountGuard,
) -> Result<(Arc<InstallProgress>, BoxFuture<'static, Result<(), Error>>), Error> {
let id = manifest.id.clone();
let s9pk_path = Path::new(BACKUP_DIR)
.join(&manifest.id)
.join(format!("{}.s9pk", id));
let metadata_path = Path::new(BACKUP_DIR).join(&id).join("metadata.cbor");
let metadata: BackupMetadata = IoFormat::Cbor.from_slice(
&tokio::fs::read(&metadata_path)
.await
.with_ctx(|_| (ErrorKind::Filesystem, metadata_path.display().to_string()))?,
)?;
let mut secrets = ctx.secret_store.acquire().await?;
let mut secrets_tx = secrets.begin().await?;
for (iface, key) in metadata.network_keys {
let k = key.0.as_slice();
sqlx::query!(
"INSERT INTO network_keys (package, interface, key) VALUES ($1, $2, $3) ON CONFLICT (package, interface) DO NOTHING",
id.to_string(),
iface.to_string(),
k,
)
.execute(secrets_tx.as_mut()).await?;
}
// DEPRECATED
for (iface, key) in metadata.tor_keys {
let k = key.0.as_slice();
sqlx::query!(
"INSERT INTO tor (package, interface, key) VALUES ($1, $2, $3) ON CONFLICT (package, interface) DO NOTHING",
id.to_string(),
iface.to_string(),
k,
)
.execute(secrets_tx.as_mut()).await?;
}
secrets_tx.commit().await?;
drop(secrets);
let len = tokio::fs::metadata(&s9pk_path)
.await
.with_ctx(|_| (ErrorKind::Filesystem, s9pk_path.display().to_string()))?
.len();
let file = File::open(&s9pk_path)
.await
.with_ctx(|_| (ErrorKind::Filesystem, s9pk_path.display().to_string()))?;
let progress = InstallProgress::new(Some(len));
let marketplace_url = metadata.marketplace_url;
let progress = Arc::new(progress);
ctx.db
.mutate(|db| {
db.as_package_data_mut().insert(
&id,
&PackageDataEntry::Restoring(PackageDataEntryRestoring {
install_progress: progress.clone(),
static_files: StaticFiles::local(
&id,
&manifest.version,
manifest.assets.icon_type(),
),
manifest: manifest.clone(),
}),
)
})
.await?;
Ok((
progress.clone(),
async move {
download_install_s9pk(ctx, manifest, marketplace_url, progress, file, None).await?;
guard.unmount().await?;
Ok(())
}
.boxed(),
))
}
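
Both restore entry points above push the per-package futures through for_each_concurrent with a limit of five, so at most five packages download and restore at once while failures are reported individually. The same pattern in isolation, assuming the futures crate and an async runtime (names are illustrative):

use futures::{stream, StreamExt};
use std::future::Future;

async fn run_limited<F>(tasks: Vec<F>)
where
    F: Future<Output = Result<(), String>>,
{
    stream::iter(tasks)
        .for_each_concurrent(5, |task| async move {
            // The real code reports failures through the notification manager;
            // here a log line stands in for that.
            if let Err(e) = task.await {
                eprintln!("restore failed: {e}");
            }
        })
        .await;
}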


@@ -0,0 +1,211 @@
use std::path::{Path, PathBuf};
use color_eyre::eyre::eyre;
use futures::TryStreamExt;
use rpc_toolkit::command;
use serde::{Deserialize, Serialize};
use sqlx::{Executor, Postgres};
use super::{BackupTarget, BackupTargetId};
use crate::context::RpcContext;
use crate::disk::mount::filesystem::cifs::Cifs;
use crate::disk::mount::filesystem::ReadOnly;
use crate::disk::mount::guard::TmpMountGuard;
use crate::disk::util::{recovery_info, EmbassyOsRecoveryInfo};
use crate::prelude::*;
use crate::util::display_none;
use crate::util::serde::KeyVal;
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct CifsBackupTarget {
hostname: String,
path: PathBuf,
username: String,
mountable: bool,
embassy_os: Option<EmbassyOsRecoveryInfo>,
}
#[command(subcommands(add, update, remove))]
pub fn cifs() -> Result<(), Error> {
Ok(())
}
#[command(display(display_none))]
pub async fn add(
#[context] ctx: RpcContext,
#[arg] hostname: String,
#[arg] path: PathBuf,
#[arg] username: String,
#[arg] password: Option<String>,
) -> Result<KeyVal<BackupTargetId, BackupTarget>, Error> {
let cifs = Cifs {
hostname,
path,
username,
password,
};
let guard = TmpMountGuard::mount(&cifs, ReadOnly).await?;
let embassy_os = recovery_info(&guard).await?;
guard.unmount().await?;
let path_string = Path::new("/").join(&cifs.path).display().to_string();
let id: i32 = sqlx::query!(
"INSERT INTO cifs_shares (hostname, path, username, password) VALUES ($1, $2, $3, $4) RETURNING id",
cifs.hostname,
path_string,
cifs.username,
cifs.password,
)
.fetch_one(&ctx.secret_store)
.await?.id;
Ok(KeyVal {
key: BackupTargetId::Cifs { id },
value: BackupTarget::Cifs(CifsBackupTarget {
hostname: cifs.hostname,
path: cifs.path,
username: cifs.username,
mountable: true,
embassy_os,
}),
})
}
#[command(display(display_none))]
pub async fn update(
#[context] ctx: RpcContext,
#[arg] id: BackupTargetId,
#[arg] hostname: String,
#[arg] path: PathBuf,
#[arg] username: String,
#[arg] password: Option<String>,
) -> Result<KeyVal<BackupTargetId, BackupTarget>, Error> {
let id = if let BackupTargetId::Cifs { id } = id {
id
} else {
return Err(Error::new(
eyre!("Backup Target ID {} Not Found", id),
ErrorKind::NotFound,
));
};
let cifs = Cifs {
hostname,
path,
username,
password,
};
let guard = TmpMountGuard::mount(&cifs, ReadOnly).await?;
let embassy_os = recovery_info(&guard).await?;
guard.unmount().await?;
let path_string = Path::new("/").join(&cifs.path).display().to_string();
if sqlx::query!(
"UPDATE cifs_shares SET hostname = $1, path = $2, username = $3, password = $4 WHERE id = $5",
cifs.hostname,
path_string,
cifs.username,
cifs.password,
id,
)
.execute(&ctx.secret_store)
.await?
.rows_affected()
== 0
{
return Err(Error::new(
eyre!("Backup Target ID {} Not Found", BackupTargetId::Cifs { id }),
ErrorKind::NotFound,
));
};
Ok(KeyVal {
key: BackupTargetId::Cifs { id },
value: BackupTarget::Cifs(CifsBackupTarget {
hostname: cifs.hostname,
path: cifs.path,
username: cifs.username,
mountable: true,
embassy_os,
}),
})
}
#[command(display(display_none))]
pub async fn remove(#[context] ctx: RpcContext, #[arg] id: BackupTargetId) -> Result<(), Error> {
let id = if let BackupTargetId::Cifs { id } = id {
id
} else {
return Err(Error::new(
eyre!("Backup Target ID {} Not Found", id),
ErrorKind::NotFound,
));
};
if sqlx::query!("DELETE FROM cifs_shares WHERE id = $1", id)
.execute(&ctx.secret_store)
.await?
.rows_affected()
== 0
{
return Err(Error::new(
eyre!("Backup Target ID {} Not Found", BackupTargetId::Cifs { id }),
ErrorKind::NotFound,
));
};
Ok(())
}
pub async fn load<Ex>(secrets: &mut Ex, id: i32) -> Result<Cifs, Error>
where
for<'a> &'a mut Ex: Executor<'a, Database = Postgres>,
{
let record = sqlx::query!(
"SELECT hostname, path, username, password FROM cifs_shares WHERE id = $1",
id
)
.fetch_one(secrets)
.await?;
Ok(Cifs {
hostname: record.hostname,
path: PathBuf::from(record.path),
username: record.username,
password: record.password,
})
}
pub async fn list<Ex>(secrets: &mut Ex) -> Result<Vec<(i32, CifsBackupTarget)>, Error>
where
for<'a> &'a mut Ex: Executor<'a, Database = Postgres>,
{
let mut records =
sqlx::query!("SELECT id, hostname, path, username, password FROM cifs_shares")
.fetch_many(secrets);
let mut cifs = Vec::new();
while let Some(query_result) = records.try_next().await? {
if let Some(record) = query_result.right() {
let mount_info = Cifs {
hostname: record.hostname,
path: PathBuf::from(record.path),
username: record.username,
password: record.password,
};
let embassy_os = async {
let guard = TmpMountGuard::mount(&mount_info, ReadOnly).await?;
let embassy_os = recovery_info(&guard).await?;
guard.unmount().await?;
Ok::<_, Error>(embassy_os)
}
.await;
cifs.push((
record.id,
CifsBackupTarget {
hostname: mount_info.hostname,
path: mount_info.path,
username: mount_info.username,
mountable: embassy_os.is_ok(),
embassy_os: embassy_os.ok().and_then(|a| a),
},
));
}
}
Ok(cifs)
}
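
The Path::new("/").join(..) step in add and update above normalizes the share path before it is stored in cifs_shares: joining onto "/" makes a relative path absolute and leaves an already-absolute path unchanged. A tiny sketch of that behavior (the path and function name are illustrative):

use std::path::Path;

fn demo_path_normalization() {
    assert_eq!(
        Path::new("/").join("backups/start9").display().to_string(),
        "/backups/start9"
    );
    // Joining an absolute path replaces the base, so a leading slash from the
    // caller does not double up.
    assert_eq!(
        Path::new("/").join("/backups/start9").display().to_string(),
        "/backups/start9"
    );
}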


@@ -0,0 +1,307 @@
use std::collections::BTreeMap;
use std::path::{Path, PathBuf};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use clap::ArgMatches;
use color_eyre::eyre::eyre;
use digest::generic_array::GenericArray;
use digest::OutputSizeUser;
use rpc_toolkit::command;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use sqlx::{Executor, Postgres};
use tokio::sync::Mutex;
use tracing::instrument;
use self::cifs::CifsBackupTarget;
use crate::context::RpcContext;
use crate::disk::mount::backup::BackupMountGuard;
use crate::disk::mount::filesystem::block_dev::BlockDev;
use crate::disk::mount::filesystem::cifs::Cifs;
use crate::disk::mount::filesystem::{FileSystem, MountType, ReadWrite};
use crate::disk::mount::guard::TmpMountGuard;
use crate::disk::util::PartitionInfo;
use crate::prelude::*;
use crate::s9pk::manifest::PackageId;
use crate::util::serde::{deserialize_from_str, display_serializable, serialize_display};
use crate::util::{display_none, Version};
pub mod cifs;
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "kebab-case")]
pub enum BackupTarget {
#[serde(rename_all = "kebab-case")]
Disk {
vendor: Option<String>,
model: Option<String>,
#[serde(flatten)]
partition_info: PartitionInfo,
},
Cifs(CifsBackupTarget),
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum BackupTargetId {
Disk { logicalname: PathBuf },
Cifs { id: i32 },
}
impl BackupTargetId {
pub async fn load<Ex>(self, secrets: &mut Ex) -> Result<BackupTargetFS, Error>
where
for<'a> &'a mut Ex: Executor<'a, Database = Postgres>,
{
Ok(match self {
BackupTargetId::Disk { logicalname } => {
BackupTargetFS::Disk(BlockDev::new(logicalname))
}
BackupTargetId::Cifs { id } => BackupTargetFS::Cifs(cifs::load(secrets, id).await?),
})
}
}
impl std::fmt::Display for BackupTargetId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BackupTargetId::Disk { logicalname } => write!(f, "disk-{}", logicalname.display()),
BackupTargetId::Cifs { id } => write!(f, "cifs-{}", id),
}
}
}
impl std::str::FromStr for BackupTargetId {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.split_once('-') {
Some(("disk", logicalname)) => Ok(BackupTargetId::Disk {
logicalname: Path::new(logicalname).to_owned(),
}),
Some(("cifs", id)) => Ok(BackupTargetId::Cifs { id: id.parse()? }),
_ => Err(Error::new(
eyre!("Invalid Backup Target ID"),
ErrorKind::InvalidBackupTargetId,
)),
}
}
}
impl<'de> Deserialize<'de> for BackupTargetId {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserialize_from_str(deserializer)
}
}
impl Serialize for BackupTargetId {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serialize_display(self, serializer)
}
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "kebab-case")]
pub enum BackupTargetFS {
Disk(BlockDev<PathBuf>),
Cifs(Cifs),
}
#[async_trait]
impl FileSystem for BackupTargetFS {
async fn mount<P: AsRef<Path> + Send + Sync>(
&self,
mountpoint: P,
mount_type: MountType,
) -> Result<(), Error> {
match self {
BackupTargetFS::Disk(a) => a.mount(mountpoint, mount_type).await,
BackupTargetFS::Cifs(a) => a.mount(mountpoint, mount_type).await,
}
}
async fn source_hash(
&self,
) -> Result<GenericArray<u8, <Sha256 as OutputSizeUser>::OutputSize>, Error> {
match self {
BackupTargetFS::Disk(a) => a.source_hash().await,
BackupTargetFS::Cifs(a) => a.source_hash().await,
}
}
}
#[command(subcommands(cifs::cifs, list, info, mount, umount))]
pub fn target() -> Result<(), Error> {
Ok(())
}
#[command(display(display_serializable))]
pub async fn list(
#[context] ctx: RpcContext,
) -> Result<BTreeMap<BackupTargetId, BackupTarget>, Error> {
let mut sql_handle = ctx.secret_store.acquire().await?;
let (disks_res, cifs) = tokio::try_join!(
crate::disk::util::list(&ctx.os_partitions),
cifs::list(sql_handle.as_mut()),
)?;
Ok(disks_res
.into_iter()
.flat_map(|mut disk| {
std::mem::take(&mut disk.partitions)
.into_iter()
.map(|part| {
(
BackupTargetId::Disk {
logicalname: part.logicalname.clone(),
},
BackupTarget::Disk {
vendor: disk.vendor.clone(),
model: disk.model.clone(),
partition_info: part,
},
)
})
.collect::<Vec<_>>()
})
.chain(
cifs.into_iter()
.map(|(id, cifs)| (BackupTargetId::Cifs { id }, BackupTarget::Cifs(cifs))),
)
.collect())
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct BackupInfo {
pub version: Version,
pub timestamp: Option<DateTime<Utc>>,
pub package_backups: BTreeMap<PackageId, PackageBackupInfo>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct PackageBackupInfo {
pub title: String,
pub version: Version,
pub os_version: Version,
pub timestamp: DateTime<Utc>,
}
fn display_backup_info(info: BackupInfo, matches: &ArgMatches) {
use prettytable::*;
if matches.is_present("format") {
return display_serializable(info, matches);
}
let mut table = Table::new();
table.add_row(row![bc =>
"ID",
"VERSION",
"OS VERSION",
"TIMESTAMP",
]);
table.add_row(row![
"EMBASSY OS",
info.version.as_str(),
info.version.as_str(),
&if let Some(ts) = &info.timestamp {
ts.to_string()
} else {
"N/A".to_owned()
},
]);
for (id, info) in info.package_backups {
let row = row![
&*id,
info.version.as_str(),
info.os_version.as_str(),
&info.timestamp.to_string(),
];
table.add_row(row);
}
table.print_tty(false).unwrap();
}
#[command(display(display_backup_info))]
#[instrument(skip(ctx, password))]
pub async fn info(
#[context] ctx: RpcContext,
#[arg(rename = "target-id")] target_id: BackupTargetId,
#[arg] password: String,
) -> Result<BackupInfo, Error> {
let guard = BackupMountGuard::mount(
TmpMountGuard::mount(
&target_id
.load(ctx.secret_store.acquire().await?.as_mut())
.await?,
ReadWrite,
)
.await?,
&password,
)
.await?;
let res = guard.metadata.clone();
guard.unmount().await?;
Ok(res)
}
lazy_static::lazy_static! {
static ref USER_MOUNTS: Mutex<BTreeMap<BackupTargetId, BackupMountGuard<TmpMountGuard>>> =
Mutex::new(BTreeMap::new());
}
#[command]
#[instrument(skip_all)]
pub async fn mount(
#[context] ctx: RpcContext,
#[arg(rename = "target-id")] target_id: BackupTargetId,
#[arg] password: String,
) -> Result<String, Error> {
let mut mounts = USER_MOUNTS.lock().await;
if let Some(existing) = mounts.get(&target_id) {
return Ok(existing.as_ref().display().to_string());
}
let guard = BackupMountGuard::mount(
TmpMountGuard::mount(
&target_id
.clone()
.load(ctx.secret_store.acquire().await?.as_mut())
.await?,
ReadWrite,
)
.await?,
&password,
)
.await?;
let res = guard.as_ref().display().to_string();
mounts.insert(target_id, guard);
Ok(res)
}
#[command(display(display_none))]
#[instrument(skip_all)]
pub async fn umount(
#[context] _ctx: RpcContext,
#[arg(rename = "target-id")] target_id: Option<BackupTargetId>,
) -> Result<(), Error> {
let mut mounts = USER_MOUNTS.lock().await;
if let Some(target_id) = target_id {
if let Some(existing) = mounts.remove(&target_id) {
existing.unmount().await?;
}
} else {
for (_, existing) in std::mem::take(&mut *mounts) {
existing.unmount().await?;
}
}
Ok(())
}

View File

@@ -0,0 +1,163 @@
use avahi_sys::{
self, avahi_client_errno, avahi_entry_group_add_service, avahi_entry_group_commit,
avahi_strerror, AvahiClient,
};
fn log_str_error(action: &str, e: i32) {
unsafe {
let e_str = avahi_strerror(e);
eprintln!(
"Could not {}: {:?}",
action,
std::ffi::CStr::from_ptr(e_str)
);
}
}
pub fn main() {
let aliases: Vec<_> = std::env::args().skip(1).collect();
unsafe {
let simple_poll = avahi_sys::avahi_simple_poll_new();
let poll = avahi_sys::avahi_simple_poll_get(simple_poll);
let mut box_err = Box::pin(0 as i32);
let err_c: *mut i32 = box_err.as_mut().get_mut();
let avahi_client = avahi_sys::avahi_client_new(
poll,
avahi_sys::AvahiClientFlags::AVAHI_CLIENT_NO_FAIL,
Some(client_callback),
std::ptr::null_mut(),
err_c,
);
if avahi_client == std::ptr::null_mut::<AvahiClient>() {
log_str_error("create Avahi client", *box_err);
panic!("Failed to create Avahi Client");
}
let group = avahi_sys::avahi_entry_group_new(
avahi_client,
Some(entry_group_callback),
std::ptr::null_mut(),
);
if group == std::ptr::null_mut() {
log_str_error("create Avahi entry group", avahi_client_errno(avahi_client));
panic!("Failed to create Avahi Entry Group");
}
let mut hostname_buf = vec![0];
let hostname_raw = avahi_sys::avahi_client_get_host_name_fqdn(avahi_client);
hostname_buf.extend_from_slice(std::ffi::CStr::from_ptr(hostname_raw).to_bytes_with_nul());
let buflen = hostname_buf.len();
debug_assert!(hostname_buf.ends_with(b".local\0"));
debug_assert!(!hostname_buf[..(buflen - 7)].contains(&b'.'));
        // Rewrite the dotted FQDN into DNS wire format, where each label is preceded by a
        // length byte. The hostname is assumed to be a single label followed by ".local".
        hostname_buf[0] = (buflen - 8) as u8; // hostname label length: total minus leading byte, ".local", and trailing nul
        hostname_buf[buflen - 7] = 5; // the '.' before "local" becomes the length byte for the "local" label
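        // Illustrative example (hostname chosen arbitrarily): for the FQDN "embassy.local",
        // the buffer ends up as [7, b"embassy"..., 5, b"local"..., 0] -- the DNS wire-format
        // name used below as the CNAME target for each alias record.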
let mut res;
let http_tcp_cstr =
std::ffi::CString::new("_http._tcp").expect("Could not cast _http._tcp to c string");
res = avahi_entry_group_add_service(
group,
avahi_sys::AVAHI_IF_UNSPEC,
avahi_sys::AVAHI_PROTO_UNSPEC,
avahi_sys::AvahiPublishFlags_AVAHI_PUBLISH_USE_MULTICAST,
hostname_raw,
http_tcp_cstr.as_ptr(),
std::ptr::null(),
std::ptr::null(),
443,
            // This trailing null pointer is easy to miss: the C function is variadic, accepting
            // any number of additional arguments naming TXT records to attach to this service
            // entry, and it only stops reading arguments off the stack when it encounters a
            // null pointer. Omitting this terminator therefore causes segfaults or other
            // undefined behavior.
std::ptr::null::<libc::c_char>(),
);
if res < avahi_sys::AVAHI_OK {
log_str_error("add service to Avahi entry group", res);
panic!("Failed to load Avahi services");
}
eprintln!("Published {:?}", std::ffi::CStr::from_ptr(hostname_raw));
for alias in aliases {
let lan_address = alias + ".local";
let lan_address_ptr = std::ffi::CString::new(lan_address)
.expect("Could not cast lan address to c string");
res = avahi_sys::avahi_entry_group_add_record(
group,
avahi_sys::AVAHI_IF_UNSPEC,
avahi_sys::AVAHI_PROTO_UNSPEC,
avahi_sys::AvahiPublishFlags_AVAHI_PUBLISH_USE_MULTICAST
| avahi_sys::AvahiPublishFlags_AVAHI_PUBLISH_ALLOW_MULTIPLE,
lan_address_ptr.as_ptr(),
avahi_sys::AVAHI_DNS_CLASS_IN as u16,
avahi_sys::AVAHI_DNS_TYPE_CNAME as u16,
avahi_sys::AVAHI_DEFAULT_TTL,
hostname_buf.as_ptr().cast(),
hostname_buf.len(),
);
if res < avahi_sys::AVAHI_OK {
log_str_error("add CNAME record to Avahi entry group", res);
panic!("Failed to load Avahi services");
}
eprintln!("Published {:?}", lan_address_ptr);
}
let commit_err = avahi_entry_group_commit(group);
if commit_err < avahi_sys::AVAHI_OK {
log_str_error("reset Avahi entry group", commit_err);
panic!("Failed to load Avahi services: reset");
}
}
std::thread::park()
}
unsafe extern "C" fn entry_group_callback(
_group: *mut avahi_sys::AvahiEntryGroup,
state: avahi_sys::AvahiEntryGroupState,
_userdata: *mut core::ffi::c_void,
) {
match state {
avahi_sys::AvahiEntryGroupState_AVAHI_ENTRY_GROUP_FAILURE => {
eprintln!("AvahiCallback: EntryGroupState = AVAHI_ENTRY_GROUP_FAILURE");
}
avahi_sys::AvahiEntryGroupState_AVAHI_ENTRY_GROUP_COLLISION => {
eprintln!("AvahiCallback: EntryGroupState = AVAHI_ENTRY_GROUP_COLLISION");
}
avahi_sys::AvahiEntryGroupState_AVAHI_ENTRY_GROUP_UNCOMMITED => {
eprintln!("AvahiCallback: EntryGroupState = AVAHI_ENTRY_GROUP_UNCOMMITED");
}
avahi_sys::AvahiEntryGroupState_AVAHI_ENTRY_GROUP_ESTABLISHED => {
eprintln!("AvahiCallback: EntryGroupState = AVAHI_ENTRY_GROUP_ESTABLISHED");
}
avahi_sys::AvahiEntryGroupState_AVAHI_ENTRY_GROUP_REGISTERING => {
eprintln!("AvahiCallback: EntryGroupState = AVAHI_ENTRY_GROUP_REGISTERING");
}
other => {
eprintln!("AvahiCallback: EntryGroupState = {}", other);
}
}
}
unsafe extern "C" fn client_callback(
_group: *mut avahi_sys::AvahiClient,
state: avahi_sys::AvahiClientState,
_userdata: *mut core::ffi::c_void,
) {
match state {
avahi_sys::AvahiClientState_AVAHI_CLIENT_FAILURE => {
eprintln!("AvahiCallback: ClientState = AVAHI_CLIENT_FAILURE");
}
avahi_sys::AvahiClientState_AVAHI_CLIENT_S_RUNNING => {
eprintln!("AvahiCallback: ClientState = AVAHI_CLIENT_S_RUNNING");
}
avahi_sys::AvahiClientState_AVAHI_CLIENT_CONNECTING => {
eprintln!("AvahiCallback: ClientState = AVAHI_CLIENT_CONNECTING");
}
avahi_sys::AvahiClientState_AVAHI_CLIENT_S_COLLISION => {
eprintln!("AvahiCallback: ClientState = AVAHI_CLIENT_S_COLLISION");
}
avahi_sys::AvahiClientState_AVAHI_CLIENT_S_REGISTERING => {
eprintln!("AvahiCallback: ClientState = AVAHI_CLIENT_S_REGISTERING");
}
other => {
eprintln!("AvahiCallback: ClientState = {}", other);
}
}
}

View File

@@ -0,0 +1,9 @@
pub fn renamed(old: &str, new: &str) -> ! {
eprintln!("{old} has been renamed to {new}");
std::process::exit(1)
}
pub fn removed(name: &str) -> ! {
eprintln!("{name} has been removed");
std::process::exit(1)
}

View File

@@ -0,0 +1,59 @@
use std::path::Path;
#[cfg(feature = "avahi-alias")]
pub mod avahi_alias;
pub mod deprecated;
#[cfg(feature = "cli")]
pub mod start_cli;
#[cfg(feature = "js-engine")]
pub mod start_deno;
#[cfg(feature = "daemon")]
pub mod start_init;
#[cfg(feature = "sdk")]
pub mod start_sdk;
#[cfg(feature = "daemon")]
pub mod startd;
fn select_executable(name: &str) -> Option<fn()> {
match name {
#[cfg(feature = "avahi-alias")]
"avahi-alias" => Some(avahi_alias::main),
#[cfg(feature = "js_engine")]
"start-deno" => Some(start_deno::main),
#[cfg(feature = "cli")]
"start-cli" => Some(start_cli::main),
#[cfg(feature = "sdk")]
"start-sdk" => Some(start_sdk::main),
#[cfg(feature = "daemon")]
"startd" => Some(startd::main),
"embassy-cli" => Some(|| deprecated::renamed("embassy-cli", "start-cli")),
"embassy-sdk" => Some(|| deprecated::renamed("embassy-sdk", "start-sdk")),
"embassyd" => Some(|| deprecated::renamed("embassyd", "startd")),
"embassy-init" => Some(|| deprecated::removed("embassy-init")),
_ => None,
}
}
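/// Dispatches like a multi-call binary (busybox-style): the applet is chosen from the
/// executable name (argv[0], typically a symlink to this binary) or, failing that, from
/// the first argument, e.g. `startbox start-cli ...`.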
pub fn startbox() {
let args = std::env::args().take(2).collect::<Vec<_>>();
if let Some(x) = args
.get(0)
.and_then(|s| Path::new(&*s).file_name())
.and_then(|s| s.to_str())
.and_then(|s| select_executable(&s))
{
x()
} else if let Some(x) = args.get(1).and_then(|s| select_executable(&s)) {
x()
} else {
eprintln!(
"unknown executable: {}",
args.get(0)
.filter(|x| &**x != "startbox")
.or_else(|| args.get(1))
.map(|s| s.as_str())
.unwrap_or("N/A")
);
std::process::exit(1);
}
}

View File

@@ -0,0 +1,62 @@
use clap::Arg;
use rpc_toolkit::run_cli;
use rpc_toolkit::yajrc::RpcError;
use serde_json::Value;
use crate::context::CliContext;
use crate::util::logger::EmbassyLogger;
use crate::version::{Current, VersionT};
use crate::Error;
lazy_static::lazy_static! {
static ref VERSION_STRING: String = Current::new().semver().to_string();
}
fn inner_main() -> Result<(), Error> {
run_cli!({
command: crate::main_api,
app: app => app
.name("StartOS CLI")
.version(&**VERSION_STRING)
.arg(
clap::Arg::with_name("config")
.short('c')
.long("config")
.takes_value(true),
)
.arg(Arg::with_name("host").long("host").short('h').takes_value(true))
.arg(Arg::with_name("proxy").long("proxy").short('p').takes_value(true)),
context: matches => {
EmbassyLogger::init();
CliContext::init(matches)?
},
exit: |e: RpcError| {
match e.data {
Some(Value::String(s)) => eprintln!("{}: {}", e.message, s),
Some(Value::Object(o)) => if let Some(Value::String(s)) = o.get("details") {
eprintln!("{}: {}", e.message, s);
if let Some(Value::String(s)) = o.get("debug") {
tracing::debug!("{}", s)
}
}
Some(a) => eprintln!("{}: {}", e.message, a),
None => eprintln!("{}", e.message),
}
std::process::exit(e.code);
}
});
Ok(())
}
pub fn main() {
match inner_main() {
Ok(_) => (),
Err(e) => {
eprintln!("{}", e.source);
tracing::debug!("{:?}", e.source);
drop(e.source);
std::process::exit(e.kind as i32)
}
}
}

View File

@@ -0,0 +1,140 @@
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{command, run_cli, Context};
use serde_json::Value;
use crate::procedure::js_scripts::ExecuteArgs;
use crate::s9pk::manifest::PackageId;
use crate::util::serde::{display_serializable, parse_stdin_deserializable, IoFormat};
use crate::version::{Current, VersionT};
use crate::Error;
lazy_static::lazy_static! {
static ref VERSION_STRING: String = Current::new().semver().to_string();
}
struct DenoContext;
impl Context for DenoContext {}
#[command(subcommands(execute, sandbox))]
fn deno_api() -> Result<(), Error> {
Ok(())
}
#[command(cli_only, display(display_serializable))]
async fn execute(
#[arg(stdin, parse(parse_stdin_deserializable))] arg: ExecuteArgs,
#[allow(unused_variables)]
#[arg(long = "format")]
format: Option<IoFormat>,
) -> Result<Result<Value, (i32, String)>, Error> {
let ExecuteArgs {
procedure,
directory,
pkg_id,
pkg_version,
name,
volumes,
input,
} = arg;
PackageLogger::init(&pkg_id);
procedure
.execute_impl(&directory, &pkg_id, &pkg_version, name, &volumes, input)
.await
}
#[command(cli_only, display(display_serializable))]
async fn sandbox(
#[arg(stdin, parse(parse_stdin_deserializable))] arg: ExecuteArgs,
#[allow(unused_variables)]
#[arg(long = "format")]
format: Option<IoFormat>,
) -> Result<Result<Value, (i32, String)>, Error> {
let ExecuteArgs {
procedure,
directory,
pkg_id,
pkg_version,
name,
volumes,
input,
} = arg;
PackageLogger::init(&pkg_id);
procedure
.sandboxed_impl(&directory, &pkg_id, &pkg_version, &volumes, input, name)
.await
}
use tracing::Subscriber;
use tracing_subscriber::util::SubscriberInitExt;
#[derive(Clone)]
struct PackageLogger {}
impl PackageLogger {
fn base_subscriber(id: &PackageId) -> impl Subscriber {
use tracing_error::ErrorLayer;
use tracing_subscriber::prelude::*;
use tracing_subscriber::{fmt, EnvFilter};
let filter_layer = EnvFilter::default().add_directive(
format!("{}=warn", std::module_path!().split("::").next().unwrap())
.parse()
.unwrap(),
);
let fmt_layer = fmt::layer().with_writer(std::io::stderr).with_target(true);
let journald_layer = tracing_journald::layer()
.unwrap()
.with_syslog_identifier(format!("{id}.embassy"));
let sub = tracing_subscriber::registry()
.with(filter_layer)
.with(fmt_layer)
.with(journald_layer)
.with(ErrorLayer::default());
sub
}
pub fn init(id: &PackageId) -> Self {
Self::base_subscriber(id).init();
        color_eyre::install().unwrap_or_else(|_| tracing::warn!("color_eyre already installed"));
Self {}
}
}
fn inner_main() -> Result<(), Error> {
run_cli!({
command: deno_api,
app: app => app
.name("StartOS Deno Executor")
.version(&**VERSION_STRING),
context: _m => DenoContext,
exit: |e: RpcError| {
match e.data {
Some(Value::String(s)) => eprintln!("{}: {}", e.message, s),
Some(Value::Object(o)) => if let Some(Value::String(s)) = o.get("details") {
eprintln!("{}: {}", e.message, s);
if let Some(Value::String(s)) = o.get("debug") {
tracing::debug!("{}", s)
}
}
Some(a) => eprintln!("{}: {}", e.message, a),
None => eprintln!("{}", e.message),
}
std::process::exit(e.code);
}
});
Ok(())
}
pub fn main() {
match inner_main() {
Ok(_) => (),
Err(e) => {
eprintln!("{}", e.source);
tracing::debug!("{:?}", e.source);
drop(e.source);
std::process::exit(e.kind as i32)
}
}
}

View File

@@ -0,0 +1,268 @@
use std::net::{Ipv6Addr, SocketAddr};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tokio::process::Command;
use tracing::instrument;
use crate::context::rpc::RpcContextConfig;
use crate::context::{DiagnosticContext, InstallContext, SetupContext};
use crate::disk::fsck::RepairStrategy;
use crate::disk::main::DEFAULT_PASSWORD;
use crate::disk::REPAIR_DISK_PATH;
use crate::firmware::update_firmware;
use crate::init::STANDBY_MODE_PATH;
use crate::net::web_server::WebServer;
use crate::shutdown::Shutdown;
use crate::sound::CHIME;
use crate::util::Invoke;
use crate::{Error, ErrorKind, ResultExt, PLATFORM};
#[instrument(skip_all)]
async fn setup_or_init(cfg_path: Option<PathBuf>) -> Result<Option<Shutdown>, Error> {
if update_firmware().await?.0 {
return Ok(Some(Shutdown {
export_args: None,
restart: true,
}));
}
Command::new("ln")
.arg("-sf")
.arg("/usr/lib/startos/scripts/fake-apt")
.arg("/usr/local/bin/apt")
.invoke(crate::ErrorKind::Filesystem)
.await?;
Command::new("ln")
.arg("-sf")
.arg("/usr/lib/startos/scripts/fake-apt")
.arg("/usr/local/bin/apt-get")
.invoke(crate::ErrorKind::Filesystem)
.await?;
Command::new("ln")
.arg("-sf")
.arg("/usr/lib/startos/scripts/fake-apt")
.arg("/usr/local/bin/aptitude")
.invoke(crate::ErrorKind::Filesystem)
.await?;
Command::new("make-ssl-cert")
.arg("generate-default-snakeoil")
.arg("--force-overwrite")
.invoke(crate::ErrorKind::OpenSsl)
.await?;
if tokio::fs::metadata("/run/live/medium").await.is_ok() {
Command::new("sed")
.arg("-i")
.arg("s/PasswordAuthentication no/PasswordAuthentication yes/g")
.arg("/etc/ssh/sshd_config")
.invoke(crate::ErrorKind::Filesystem)
.await?;
Command::new("systemctl")
.arg("reload")
.arg("ssh")
.invoke(crate::ErrorKind::OpenSsh)
.await?;
let ctx = InstallContext::init(cfg_path).await?;
let server = WebServer::install(
SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 80),
ctx.clone(),
)
.await?;
tokio::time::sleep(Duration::from_secs(1)).await; // let the record state that I hate this
CHIME.play().await?;
ctx.shutdown
.subscribe()
.recv()
.await
.expect("context dropped");
server.shutdown().await;
Command::new("reboot")
.invoke(crate::ErrorKind::Unknown)
.await?;
} else if tokio::fs::metadata("/media/embassy/config/disk.guid")
.await
.is_err()
{
let ctx = SetupContext::init(cfg_path).await?;
let server = WebServer::setup(
SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 80),
ctx.clone(),
)
.await?;
tokio::time::sleep(Duration::from_secs(1)).await; // let the record state that I hate this
CHIME.play().await?;
ctx.shutdown
.subscribe()
.recv()
.await
.expect("context dropped");
server.shutdown().await;
tokio::task::yield_now().await;
if let Err(e) = Command::new("killall")
.arg("firefox-esr")
.invoke(ErrorKind::NotFound)
.await
{
tracing::error!("Failed to kill kiosk: {}", e);
tracing::debug!("{:?}", e);
}
} else {
let cfg = RpcContextConfig::load(cfg_path).await?;
let guid_string = tokio::fs::read_to_string("/media/embassy/config/disk.guid") // unique identifier for volume group - keeps track of the disk that goes with your embassy
.await?;
let guid = guid_string.trim();
let requires_reboot = crate::disk::main::import(
guid,
cfg.datadir(),
if tokio::fs::metadata(REPAIR_DISK_PATH).await.is_ok() {
RepairStrategy::Aggressive
} else {
RepairStrategy::Preen
},
if guid.ends_with("_UNENC") {
None
} else {
Some(DEFAULT_PASSWORD)
},
)
.await?;
if tokio::fs::metadata(REPAIR_DISK_PATH).await.is_ok() {
tokio::fs::remove_file(REPAIR_DISK_PATH)
.await
.with_ctx(|_| (crate::ErrorKind::Filesystem, REPAIR_DISK_PATH))?;
}
if requires_reboot.0 {
crate::disk::main::export(guid, cfg.datadir()).await?;
Command::new("reboot")
.invoke(crate::ErrorKind::Unknown)
.await?;
}
tracing::info!("Loaded Disk");
crate::init::init(&cfg).await?;
}
Ok(None)
}
async fn run_script_if_exists<P: AsRef<Path>>(path: P) {
let script = path.as_ref();
if script.exists() {
match Command::new("/bin/bash").arg(script).spawn() {
Ok(mut c) => {
if let Err(e) = c.wait().await {
tracing::error!("Error Running {}: {}", script.display(), e);
tracing::debug!("{:?}", e);
}
}
Err(e) => {
tracing::error!("Error Running {}: {}", script.display(), e);
tracing::debug!("{:?}", e);
}
}
}
}
#[instrument(skip_all)]
async fn inner_main(cfg_path: Option<PathBuf>) -> Result<Option<Shutdown>, Error> {
if &*PLATFORM == "raspberrypi" && tokio::fs::metadata(STANDBY_MODE_PATH).await.is_ok() {
tokio::fs::remove_file(STANDBY_MODE_PATH).await?;
Command::new("sync").invoke(ErrorKind::Filesystem).await?;
crate::sound::SHUTDOWN.play().await?;
futures::future::pending::<()>().await;
}
crate::sound::BEP.play().await?;
run_script_if_exists("/media/embassy/config/preinit.sh").await;
let res = match setup_or_init(cfg_path.clone()).await {
Err(e) => {
async move {
tracing::error!("{}", e.source);
tracing::debug!("{}", e.source);
crate::sound::BEETHOVEN.play().await?;
let ctx = DiagnosticContext::init(
cfg_path,
if tokio::fs::metadata("/media/embassy/config/disk.guid")
.await
.is_ok()
{
Some(Arc::new(
tokio::fs::read_to_string("/media/embassy/config/disk.guid") // unique identifier for volume group - keeps track of the disk that goes with your embassy
.await?
.trim()
.to_owned(),
))
} else {
None
},
e,
)
.await?;
let server = WebServer::diagnostic(
SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 80),
ctx.clone(),
)
.await?;
let shutdown = ctx.shutdown.subscribe().recv().await.unwrap();
server.shutdown().await;
Ok(shutdown)
}
.await
}
Ok(s) => Ok(s),
};
run_script_if_exists("/media/embassy/config/postinit.sh").await;
res
}
pub fn main() {
let matches = clap::App::new("start-init")
.arg(
clap::Arg::with_name("config")
.short('c')
.long("config")
.takes_value(true),
)
.get_matches();
let cfg_path = matches.value_of("config").map(|p| Path::new(p).to_owned());
let res = {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("failed to initialize runtime");
rt.block_on(inner_main(cfg_path))
};
match res {
Ok(Some(shutdown)) => shutdown.execute(),
Ok(None) => (),
Err(e) => {
eprintln!("{}", e.source);
tracing::debug!("{:?}", e.source);
drop(e.source);
std::process::exit(e.kind as i32)
}
}
}

View File

@@ -0,0 +1,61 @@
use rpc_toolkit::run_cli;
use rpc_toolkit::yajrc::RpcError;
use serde_json::Value;
use crate::context::SdkContext;
use crate::util::logger::EmbassyLogger;
use crate::version::{Current, VersionT};
use crate::Error;
lazy_static::lazy_static! {
static ref VERSION_STRING: String = Current::new().semver().to_string();
}
fn inner_main() -> Result<(), Error> {
run_cli!({
command: crate::portable_api,
app: app => app
.name("StartOS SDK")
.version(&**VERSION_STRING)
.arg(
clap::Arg::with_name("config")
.short('c')
.long("config")
.takes_value(true),
),
context: matches => {
            if std::env::var("RUST_LOG").is_err() {
std::env::set_var("RUST_LOG", "embassy=warn,js_engine=warn");
}
EmbassyLogger::init();
SdkContext::init(matches)?
},
exit: |e: RpcError| {
match e.data {
Some(Value::String(s)) => eprintln!("{}: {}", e.message, s),
Some(Value::Object(o)) => if let Some(Value::String(s)) = o.get("details") {
eprintln!("{}: {}", e.message, s);
if let Some(Value::String(s)) = o.get("debug") {
tracing::debug!("{}", s)
}
}
Some(a) => eprintln!("{}: {}", e.message, a),
None => eprintln!("{}", e.message),
}
std::process::exit(e.code);
}
});
Ok(())
}
pub fn main() {
match inner_main() {
Ok(_) => (),
Err(e) => {
eprintln!("{}", e.source);
tracing::debug!("{:?}", e.source);
drop(e.source);
std::process::exit(e.kind as i32)
}
}
}

View File

@@ -0,0 +1,187 @@
use std::net::{Ipv6Addr, SocketAddr};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use color_eyre::eyre::eyre;
use futures::{FutureExt, TryFutureExt};
use tokio::signal::unix::signal;
use tracing::instrument;
use crate::context::{DiagnosticContext, RpcContext};
use crate::net::web_server::WebServer;
use crate::shutdown::Shutdown;
use crate::system::launch_metrics_task;
use crate::util::logger::EmbassyLogger;
use crate::{Error, ErrorKind, ResultExt};
#[instrument(skip_all)]
async fn inner_main(cfg_path: Option<PathBuf>) -> Result<Option<Shutdown>, Error> {
let (rpc_ctx, server, shutdown) = async {
let rpc_ctx = RpcContext::init(
cfg_path,
Arc::new(
tokio::fs::read_to_string("/media/embassy/config/disk.guid") // unique identifier for volume group - keeps track of the disk that goes with your embassy
.await?
.trim()
.to_owned(),
),
)
.await?;
crate::hostname::sync_hostname(&rpc_ctx.account.read().await.hostname).await?;
let server = WebServer::main(
SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 80),
rpc_ctx.clone(),
)
.await?;
let mut shutdown_recv = rpc_ctx.shutdown.subscribe();
let sig_handler_ctx = rpc_ctx.clone();
let sig_handler = tokio::spawn(async move {
use tokio::signal::unix::SignalKind;
futures::future::select_all(
[
SignalKind::interrupt(),
SignalKind::quit(),
SignalKind::terminate(),
]
.iter()
.map(|s| {
async move {
signal(*s)
.unwrap_or_else(|_| panic!("register {:?} handler", s))
.recv()
.await
}
.boxed()
}),
)
.await;
sig_handler_ctx
.shutdown
.send(None)
.map_err(|_| ())
.expect("send shutdown signal");
});
let metrics_ctx = rpc_ctx.clone();
let metrics_task = tokio::spawn(async move {
launch_metrics_task(&metrics_ctx.metrics_cache, || {
metrics_ctx.shutdown.subscribe()
})
.await
});
crate::sound::CHIME.play().await?;
metrics_task
.map_err(|e| {
Error::new(
eyre!("{}", e).wrap_err("Metrics daemon panicked!"),
ErrorKind::Unknown,
)
})
.map_ok(|_| tracing::debug!("Metrics daemon Shutdown"))
.await?;
let shutdown = shutdown_recv
.recv()
.await
.with_kind(crate::ErrorKind::Unknown)?;
sig_handler.abort();
Ok::<_, Error>((rpc_ctx, server, shutdown))
}
.await?;
server.shutdown().await;
rpc_ctx.shutdown().await?;
tracing::info!("RPC Context is dropped");
Ok(shutdown)
}
pub fn main() {
EmbassyLogger::init();
if !Path::new("/run/embassy/initialized").exists() {
super::start_init::main();
std::fs::write("/run/embassy/initialized", "").unwrap();
}
let matches = clap::App::new("startd")
.arg(
clap::Arg::with_name("config")
.short('c')
.long("config")
.takes_value(true),
)
.get_matches();
let cfg_path = matches.value_of("config").map(|p| Path::new(p).to_owned());
let res = {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("failed to initialize runtime");
rt.block_on(async {
match inner_main(cfg_path.clone()).await {
Ok(a) => Ok(a),
Err(e) => {
async {
tracing::error!("{}", e.source);
tracing::debug!("{:?}", e.source);
crate::sound::BEETHOVEN.play().await?;
let ctx = DiagnosticContext::init(
cfg_path,
if tokio::fs::metadata("/media/embassy/config/disk.guid")
.await
.is_ok()
{
Some(Arc::new(
tokio::fs::read_to_string("/media/embassy/config/disk.guid") // unique identifier for volume group - keeps track of the disk that goes with your embassy
.await?
.trim()
.to_owned(),
))
} else {
None
},
e,
)
.await?;
let server = WebServer::diagnostic(
SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 80),
ctx.clone(),
)
.await?;
let mut shutdown = ctx.shutdown.subscribe();
let shutdown =
shutdown.recv().await.with_kind(crate::ErrorKind::Unknown)?;
server.shutdown().await;
Ok::<_, Error>(shutdown)
}
.await
}
}
})
};
match res {
Ok(None) => (),
Ok(Some(s)) => s.execute(),
Err(e) => {
eprintln!("{}", e.source);
tracing::debug!("{:?}", e.source);
drop(e.source);
std::process::exit(e.kind as i32)
}
}
}

View File

@@ -0,0 +1,116 @@
use std::collections::{BTreeMap, BTreeSet};
use color_eyre::eyre::eyre;
use models::ImageId;
use patch_db::HasModel;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use super::{Config, ConfigSpec};
use crate::context::RpcContext;
use crate::dependencies::Dependencies;
use crate::prelude::*;
use crate::procedure::docker::DockerContainers;
use crate::procedure::{PackageProcedure, ProcedureName};
use crate::s9pk::manifest::PackageId;
use crate::status::health_check::HealthCheckId;
use crate::util::Version;
use crate::volume::Volumes;
use crate::{Error, ResultExt};
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct ConfigRes {
pub config: Option<Config>,
pub spec: ConfigSpec,
}
#[derive(Clone, Debug, Deserialize, Serialize, HasModel)]
#[model = "Model<Self>"]
pub struct ConfigActions {
pub get: PackageProcedure,
pub set: PackageProcedure,
}
impl ConfigActions {
#[instrument(skip_all)]
pub fn validate(
&self,
_container: &Option<DockerContainers>,
eos_version: &Version,
volumes: &Volumes,
image_ids: &BTreeSet<ImageId>,
) -> Result<(), Error> {
self.get
.validate(eos_version, volumes, image_ids, true)
.with_ctx(|_| (crate::ErrorKind::ValidateS9pk, "Config Get"))?;
self.set
.validate(eos_version, volumes, image_ids, true)
.with_ctx(|_| (crate::ErrorKind::ValidateS9pk, "Config Set"))?;
Ok(())
}
#[instrument(skip_all)]
pub async fn get(
&self,
ctx: &RpcContext,
pkg_id: &PackageId,
pkg_version: &Version,
volumes: &Volumes,
) -> Result<ConfigRes, Error> {
self.get
.execute(
ctx,
pkg_id,
pkg_version,
ProcedureName::GetConfig,
volumes,
None::<()>,
None,
)
.await
.and_then(|res| {
res.map_err(|e| Error::new(eyre!("{}", e.1), crate::ErrorKind::ConfigGen))
})
}
#[instrument(skip_all)]
pub async fn set(
&self,
ctx: &RpcContext,
pkg_id: &PackageId,
pkg_version: &Version,
dependencies: &Dependencies,
volumes: &Volumes,
input: &Config,
) -> Result<SetResult, Error> {
let res: SetResult = self
.set
.execute(
ctx,
pkg_id,
pkg_version,
ProcedureName::SetConfig,
volumes,
Some(input),
None,
)
.await
.and_then(|res| {
res.map_err(|e| {
Error::new(eyre!("{}", e.1), crate::ErrorKind::ConfigRulesViolation)
})
})?;
Ok(SetResult {
depends_on: res
.depends_on
.into_iter()
.filter(|(pkg, _)| dependencies.0.contains_key(pkg))
.collect(),
})
}
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct SetResult {
pub depends_on: BTreeMap<PackageId, BTreeSet<HealthCheckId>>,
}

View File

@@ -0,0 +1,287 @@
use std::collections::BTreeMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use color_eyre::eyre::eyre;
use indexmap::IndexSet;
use itertools::Itertools;
use models::{ErrorKind, OptionExt};
use patch_db::value::InternedString;
use patch_db::Value;
use regex::Regex;
use rpc_toolkit::command;
use tracing::instrument;
use crate::context::RpcContext;
use crate::prelude::*;
use crate::s9pk::manifest::PackageId;
use crate::util::display_none;
use crate::util::serde::{display_serializable, parse_stdin_deserializable, IoFormat};
use crate::Error;
pub mod action;
pub mod spec;
pub mod util;
pub use spec::{ConfigSpec, Defaultable};
use util::NumRange;
use self::action::ConfigRes;
use self::spec::ValueSpecPointer;
pub type Config = patch_db::value::InOMap<InternedString, Value>;
pub trait TypeOf {
fn type_of(&self) -> &'static str;
}
impl TypeOf for Value {
fn type_of(&self) -> &'static str {
match self {
Value::Array(_) => "list",
Value::Bool(_) => "boolean",
Value::Null => "null",
Value::Number(_) => "number",
Value::Object(_) => "object",
Value::String(_) => "string",
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum ConfigurationError {
#[error("Timeout Error")]
TimeoutError(#[from] TimeoutError),
#[error("No Match: {0}")]
NoMatch(#[from] NoMatchWithPath),
#[error("System Error: {0}")]
SystemError(Error),
#[error("Permission Denied: {0}")]
PermissionDenied(ValueSpecPointer),
}
impl From<ConfigurationError> for Error {
fn from(err: ConfigurationError) -> Self {
let kind = match &err {
ConfigurationError::SystemError(e) => e.kind,
_ => crate::ErrorKind::ConfigGen,
};
crate::Error::new(err, kind)
}
}
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[error("Timeout Error")]
pub struct TimeoutError;
#[derive(Clone, Debug, thiserror::Error)]
pub struct NoMatchWithPath {
pub path: Vec<InternedString>,
pub error: MatchError,
}
impl NoMatchWithPath {
pub fn new(error: MatchError) -> Self {
NoMatchWithPath {
path: Vec::new(),
error,
}
}
pub fn prepend(mut self, seg: InternedString) -> Self {
self.path.push(seg);
self
}
}
impl std::fmt::Display for NoMatchWithPath {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}: {}", self.path.iter().rev().join("."), self.error)
}
}
impl From<NoMatchWithPath> for Error {
fn from(e: NoMatchWithPath) -> Self {
ConfigurationError::from(e).into()
}
}
#[derive(Clone, Debug, thiserror::Error)]
pub enum MatchError {
#[error("String {0:?} Does Not Match Pattern {1}")]
Pattern(Arc<String>, Regex),
#[error("String {0:?} Is Not In Enum {1:?}")]
Enum(Arc<String>, IndexSet<String>),
#[error("Field Is Not Nullable")]
NotNullable,
#[error("Length Mismatch: expected {0}, actual: {1}")]
LengthMismatch(NumRange<usize>, usize),
#[error("Invalid Type: expected {0}, actual: {1}")]
InvalidType(&'static str, &'static str),
#[error("Number Out Of Range: expected {0}, actual: {1}")]
OutOfRange(NumRange<f64>, f64),
#[error("Number Is Not Integral: {0}")]
NonIntegral(f64),
#[error("Variant {0:?} Is Not In Union {1:?}")]
Union(Arc<String>, IndexSet<String>),
#[error("Variant Is Missing Tag {0:?}")]
MissingTag(InternedString),
#[error("Property {0:?} Of Variant {1:?} Conflicts With Union Tag")]
PropertyMatchesUnionTag(InternedString, String),
#[error("Name of Property {0:?} Conflicts With Map Tag Name")]
PropertyNameMatchesMapTag(String),
#[error("Pointer Is Invalid: {0}")]
InvalidPointer(spec::ValueSpecPointer),
#[error("Object Key Is Invalid: {0}")]
InvalidKey(String),
#[error("Value In List Is Not Unique")]
ListUniquenessViolation,
}
#[command(rename = "config-spec", cli_only, blocking, display(display_none))]
pub fn verify_spec(#[arg] path: PathBuf) -> Result<(), Error> {
let mut file = std::fs::File::open(&path)?;
let format = match path.extension().and_then(|s| s.to_str()) {
Some("yaml") | Some("yml") => IoFormat::Yaml,
Some("json") => IoFormat::Json,
Some("toml") => IoFormat::Toml,
Some("cbor") => IoFormat::Cbor,
_ => {
return Err(Error::new(
eyre!("Unknown file format. Expected one of yaml, json, toml, cbor."),
crate::ErrorKind::Deserialization,
));
}
};
let _: ConfigSpec = format.from_reader(&mut file)?;
Ok(())
}
#[command(subcommands(get, set))]
pub fn config(#[arg] id: PackageId) -> Result<PackageId, Error> {
Ok(id)
}
#[command(display(display_serializable))]
#[instrument(skip_all)]
pub async fn get(
#[context] ctx: RpcContext,
#[parent_data] id: PackageId,
#[allow(unused_variables)]
#[arg(long = "format")]
format: Option<IoFormat>,
) -> Result<ConfigRes, Error> {
let db = ctx.db.peek().await;
let manifest = db
.as_package_data()
.as_idx(&id)
.or_not_found(&id)?
.as_installed()
.or_not_found(&id)?
.as_manifest();
let action = manifest
.as_config()
.de()?
.ok_or_else(|| Error::new(eyre!("{} has no config", id), crate::ErrorKind::NotFound))?;
let volumes = manifest.as_volumes().de()?;
let version = manifest.as_version().de()?;
action.get(&ctx, &id, &version, &volumes).await
}
#[command(
subcommands(self(set_impl(async, context(RpcContext))), set_dry),
display(display_none),
metadata(sync_db = true)
)]
#[instrument(skip_all)]
pub fn set(
#[parent_data] id: PackageId,
#[allow(unused_variables)]
#[arg(long = "format")]
format: Option<IoFormat>,
#[arg(long = "timeout")] timeout: Option<crate::util::serde::Duration>,
#[arg(stdin, parse(parse_stdin_deserializable))] config: Option<Config>,
) -> Result<(PackageId, Option<Config>, Option<Duration>), Error> {
Ok((id, config, timeout.map(|d| *d)))
}
#[command(rename = "dry", display(display_serializable))]
#[instrument(skip_all)]
pub async fn set_dry(
#[context] ctx: RpcContext,
#[parent_data] (id, config, timeout): (PackageId, Option<Config>, Option<Duration>),
) -> Result<BTreeMap<PackageId, String>, Error> {
let breakages = BTreeMap::new();
let overrides = Default::default();
let configure_context = ConfigureContext {
breakages,
timeout,
config,
dry_run: true,
overrides,
};
let breakages = configure(&ctx, &id, configure_context).await?;
Ok(breakages)
}
pub struct ConfigureContext {
pub breakages: BTreeMap<PackageId, String>,
pub timeout: Option<Duration>,
pub config: Option<Config>,
pub overrides: BTreeMap<PackageId, Config>,
pub dry_run: bool,
}
#[instrument(skip_all)]
pub async fn set_impl(
ctx: RpcContext,
(id, config, timeout): (PackageId, Option<Config>, Option<Duration>),
) -> Result<(), Error> {
let breakages = BTreeMap::new();
let overrides = Default::default();
let configure_context = ConfigureContext {
breakages,
timeout,
config,
dry_run: false,
overrides,
};
configure(&ctx, &id, configure_context).await?;
Ok(())
}
#[instrument(skip_all)]
pub async fn configure(
ctx: &RpcContext,
id: &PackageId,
configure_context: ConfigureContext,
) -> Result<BTreeMap<PackageId, String>, Error> {
let db = ctx.db.peek().await;
let package = db
.as_package_data()
.as_idx(id)
.or_not_found(&id)?
.as_installed()
.or_not_found(&id)?;
let version = package.as_manifest().as_version().de()?;
ctx.managers
.get(&(id.clone(), version.clone()))
.await
.ok_or_else(|| {
Error::new(
eyre!("There is no manager running for {id:?} and {version:?}"),
ErrorKind::Unknown,
)
})?
.configure(configure_context)
.await
}
macro_rules! not_found {
($x:expr) => {
crate::Error::new(
color_eyre::eyre::eyre!("Could not find {} at {}:{}", $x, module_path!(), line!()),
crate::ErrorKind::Incoherent,
)
};
}
pub(crate) use not_found;

File diff suppressed because it is too large

View File

@@ -0,0 +1,406 @@
use std::borrow::Cow;
use std::ops::{Bound, RangeBounds, RangeInclusive};
use patch_db::Value;
use rand::distributions::Distribution;
use rand::Rng;
use super::Config;
pub const STATIC_NULL: Value = Value::Null;
#[derive(Clone, Debug)]
pub struct CharSet(pub Vec<(RangeInclusive<char>, usize)>, usize);
impl CharSet {
pub fn contains(&self, c: &char) -> bool {
self.0.iter().any(|r| r.0.contains(c))
}
pub fn gen<R: Rng>(&self, rng: &mut R) -> char {
let mut idx = rng.gen_range(0..self.1);
for r in &self.0 {
if idx < r.1 {
return std::convert::TryFrom::try_from(
rand::distributions::Uniform::new_inclusive(
u32::from(*r.0.start()),
u32::from(*r.0.end()),
)
.sample(rng),
)
.unwrap();
} else {
idx -= r.1;
}
}
unreachable!()
}
}
impl Default for CharSet {
fn default() -> Self {
CharSet(vec![('!'..='~', 94)], 94)
}
}
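// A `CharSet` is (de)serialized as a comma-separated list of characters and inclusive
// ranges. For example (illustrative), "a-z,A-Z,0-9" covers 62 characters, and the
// default "!-~" covers the 94 printable ASCII characters.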
impl<'de> serde::de::Deserialize<'de> for CharSet {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
let mut res = Vec::new();
let mut len = 0;
let mut a: Option<char> = None;
let mut b: Option<char> = None;
let mut in_range = false;
for c in s.chars() {
match c {
',' => match (a, b, in_range) {
(Some(start), Some(end), _) => {
if !end.is_ascii() {
return Err(serde::de::Error::custom("Invalid Character"));
}
if start >= end {
return Err(serde::de::Error::custom("Invalid Bounds"));
}
let l = u32::from(end) - u32::from(start) + 1;
res.push((start..=end, l as usize));
len += l as usize;
a = None;
b = None;
in_range = false;
}
(Some(start), None, false) => {
len += 1;
res.push((start..=start, 1));
a = None;
}
(Some(_), None, true) => {
b = Some(',');
}
(None, None, false) => {
a = Some(',');
}
_ => {
return Err(serde::de::Error::custom("Syntax Error"));
}
},
'-' => {
if a.is_none() {
a = Some('-');
} else if !in_range {
in_range = true;
} else if b.is_none() {
b = Some('-')
} else {
return Err(serde::de::Error::custom("Syntax Error"));
}
}
_ => {
if a.is_none() {
a = Some(c);
} else if in_range && b.is_none() {
b = Some(c);
} else {
return Err(serde::de::Error::custom("Syntax Error"));
}
}
}
}
match (a, b) {
(Some(start), Some(end)) => {
if !end.is_ascii() {
return Err(serde::de::Error::custom("Invalid Character"));
}
if start >= end {
return Err(serde::de::Error::custom("Invalid Bounds"));
}
let l = u32::from(end) - u32::from(start) + 1;
res.push((start..=end, l as usize));
len += l as usize;
}
(Some(c), None) => {
len += 1;
res.push((c..=c, 1));
}
_ => (),
}
Ok(CharSet(res, len))
}
}
impl serde::ser::Serialize for CharSet {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::ser::Serializer,
{
<&str>::serialize(
&self
.0
.iter()
.map(|r| match r.1 {
1 => format!("{}", r.0.start()),
_ => format!("{}-{}", r.0.start(), r.0.end()),
})
.collect::<Vec<_>>()
.join(",")
.as_str(),
serializer,
)
}
}
pub trait MergeWith {
fn merge_with(&mut self, other: &serde_json::Value);
}
impl MergeWith for serde_json::Value {
fn merge_with(&mut self, other: &serde_json::Value) {
use serde_json::Value::Object;
if let (Object(orig), Object(ref other)) = (self, other) {
for (key, val) in other.into_iter() {
match (orig.get_mut(key), val) {
(Some(new_orig @ Object(_)), other @ Object(_)) => {
new_orig.merge_with(other);
}
(None, _) => {
orig.insert(key.clone(), val.clone());
}
_ => (),
}
}
}
}
}
#[test]
fn merge_with_tests() {
use serde_json::json;
let mut a = json!(
{"a": 1, "c": {"d": "123"}, "i": [1,2,3], "j": {}, "k":[1,2,3], "l": "test"}
);
a.merge_with(
&json!({"a":"a", "b": "b", "c":{"d":"d", "e":"e"}, "f":{"g":"g"}, "h": [1,2,3], "i":"i", "j":[1,2,3], "k":{}}),
);
assert_eq!(
a,
json!({"a": 1, "c": {"d": "123", "e":"e"}, "b":"b", "f": {"g":"g"}, "h":[1,2,3], "i":[1,2,3], "j": {}, "k":[1,2,3], "l": "test"})
)
}
pub mod serde_regex {
use regex::Regex;
use serde::*;
pub fn serialize<S>(regex: &Regex, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
<&str>::serialize(&regex.as_str(), serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Regex, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Regex::new(&s).map_err(|e| de::Error::custom(e))
}
}
#[derive(Clone, Debug)]
pub struct NumRange<T: std::str::FromStr + std::fmt::Display + std::cmp::PartialOrd>(
pub (Bound<T>, Bound<T>),
);
impl<T> std::ops::Deref for NumRange<T>
where
T: std::str::FromStr + std::fmt::Display + std::cmp::PartialOrd,
{
type Target = (Bound<T>, Bound<T>);
fn deref(&self) -> &Self::Target {
&self.0
}
}
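// `NumRange` uses interval notation in its textual form. For example (illustrative),
// "[0,100)" means 0 <= n < 100, "(*,10]" means n <= 10 with no lower bound, and a
// missing right-hand side is treated as unbounded above.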
impl<'de, T> serde::de::Deserialize<'de> for NumRange<T>
where
T: std::str::FromStr + std::fmt::Display + std::cmp::PartialOrd,
<T as std::str::FromStr>::Err: std::fmt::Display,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
let mut split = s.split(",");
let start = split
.next()
.map(|s| match s.get(..1) {
Some("(") => match s.get(1..2) {
Some("*") => Ok(Bound::Unbounded),
_ => s[1..]
.trim()
.parse()
.map(Bound::Excluded)
.map_err(|e| serde::de::Error::custom(e)),
},
Some("[") => s[1..]
.trim()
.parse()
.map(Bound::Included)
.map_err(|e| serde::de::Error::custom(e)),
_ => Err(serde::de::Error::custom(format!(
"Could not parse left bound: {}",
s
))),
})
.transpose()?
.unwrap();
let end = split
.next()
.map(|s| match s.get(s.len() - 1..) {
Some(")") => match s.get(s.len() - 2..s.len() - 1) {
Some("*") => Ok(Bound::Unbounded),
_ => s[..s.len() - 1]
.trim()
.parse()
.map(Bound::Excluded)
.map_err(|e| serde::de::Error::custom(e)),
},
Some("]") => s[..s.len() - 1]
.trim()
.parse()
.map(Bound::Included)
.map_err(|e| serde::de::Error::custom(e)),
_ => Err(serde::de::Error::custom(format!(
"Could not parse right bound: {}",
s
))),
})
.transpose()?
.unwrap_or(Bound::Unbounded);
Ok(NumRange((start, end)))
}
}
impl<T> std::fmt::Display for NumRange<T>
where
T: std::str::FromStr + std::fmt::Display + std::cmp::PartialOrd,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.start_bound() {
Bound::Excluded(n) => write!(f, "({},", n)?,
Bound::Included(n) => write!(f, "[{},", n)?,
Bound::Unbounded => write!(f, "(*,")?,
};
match self.end_bound() {
Bound::Excluded(n) => write!(f, "{})", n),
Bound::Included(n) => write!(f, "{}]", n),
Bound::Unbounded => write!(f, "*)"),
}
}
}
impl<T> serde::ser::Serialize for NumRange<T>
where
T: std::str::FromStr + std::fmt::Display + std::cmp::PartialOrd,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::ser::Serializer,
{
<&str>::serialize(&format!("{}", self).as_str(), serializer)
}
}
#[derive(Clone, Debug)]
pub enum UniqueBy {
Any(Vec<UniqueBy>),
All(Vec<UniqueBy>),
Exactly(String),
NotUnique,
}
impl UniqueBy {
pub fn eq(&self, lhs: &Config, rhs: &Config) -> bool {
match self {
UniqueBy::Any(any) => any.iter().any(|u| u.eq(lhs, rhs)),
UniqueBy::All(all) => all.iter().all(|u| u.eq(lhs, rhs)),
UniqueBy::Exactly(key) => lhs.get(&**key) == rhs.get(&**key),
UniqueBy::NotUnique => false,
}
}
}
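// For example (illustrative): `UniqueBy::Any(vec![Exactly("port".into()), Exactly("name".into())])`
// treats two configs as duplicates if either their "port" values or their "name" values
// match, while `NotUnique` never reports a match.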
impl Default for UniqueBy {
fn default() -> Self {
UniqueBy::NotUnique
}
}
impl<'de> serde::de::Deserialize<'de> for UniqueBy {
fn deserialize<D: serde::de::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
struct Visitor;
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = UniqueBy;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "a key, an \"any\" object, or an \"all\" object")
}
fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
Ok(UniqueBy::Exactly(v.to_owned()))
}
fn visit_string<E: serde::de::Error>(self, v: String) -> Result<Self::Value, E> {
Ok(UniqueBy::Exactly(v))
}
fn visit_map<A: serde::de::MapAccess<'de>>(
self,
mut map: A,
) -> Result<Self::Value, A::Error> {
let mut variant = None;
while let Some(key) = map.next_key::<Cow<str>>()? {
match key.as_ref() {
"any" => {
return Ok(UniqueBy::Any(map.next_value()?));
}
"all" => {
return Ok(UniqueBy::All(map.next_value()?));
}
_ => {
variant = Some(key);
}
}
}
Err(serde::de::Error::unknown_variant(
variant.unwrap_or_default().as_ref(),
&["any", "all"],
))
}
fn visit_unit<E: serde::de::Error>(self) -> Result<Self::Value, E> {
Ok(UniqueBy::NotUnique)
}
fn visit_none<E: serde::de::Error>(self) -> Result<Self::Value, E> {
Ok(UniqueBy::NotUnique)
}
}
deserializer.deserialize_any(Visitor)
}
}
impl serde::ser::Serialize for UniqueBy {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::ser::Serializer,
{
use serde::ser::SerializeMap;
match self {
UniqueBy::Any(any) => {
let mut map = serializer.serialize_map(Some(1))?;
map.serialize_key("any")?;
map.serialize_value(any)?;
map.end()
}
UniqueBy::All(all) => {
let mut map = serializer.serialize_map(Some(1))?;
map.serialize_key("all")?;
map.serialize_value(all)?;
map.end()
}
UniqueBy::Exactly(key) => serializer.serialize_str(key),
UniqueBy::NotUnique => serializer.serialize_unit(),
}
}
}

View File

@@ -0,0 +1,185 @@
use std::fs::File;
use std::io::BufReader;
use std::net::Ipv4Addr;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use clap::ArgMatches;
use color_eyre::eyre::eyre;
use cookie_store::{CookieStore, RawCookie};
use josekit::jwk::Jwk;
use reqwest::Proxy;
use reqwest_cookie_store::CookieStoreMutex;
use rpc_toolkit::reqwest::{Client, Url};
use rpc_toolkit::url::Host;
use rpc_toolkit::Context;
use serde::Deserialize;
use tracing::instrument;
use super::setup::CURRENT_SECRET;
use crate::middleware::auth::LOCAL_AUTH_COOKIE_PATH;
use crate::util::config::{load_config_from_paths, local_config_path};
use crate::ResultExt;
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct CliContextConfig {
pub host: Option<Url>,
#[serde(deserialize_with = "crate::util::serde::deserialize_from_str_opt")]
#[serde(default)]
pub proxy: Option<Url>,
pub cookie_path: Option<PathBuf>,
}
#[derive(Debug)]
pub struct CliContextSeed {
pub base_url: Url,
pub rpc_url: Url,
pub client: Client,
pub cookie_store: Arc<CookieStoreMutex>,
pub cookie_path: PathBuf,
}
impl Drop for CliContextSeed {
fn drop(&mut self) {
let tmp = format!("{}.tmp", self.cookie_path.display());
let parent_dir = self.cookie_path.parent().unwrap_or(Path::new("/"));
if !parent_dir.exists() {
std::fs::create_dir_all(&parent_dir).unwrap();
}
let mut writer = fd_lock_rs::FdLock::lock(
File::create(&tmp).unwrap(),
fd_lock_rs::LockType::Exclusive,
true,
)
.unwrap();
let mut store = self.cookie_store.lock().unwrap();
store.remove("localhost", "", "local");
store.save_json(&mut *writer).unwrap();
writer.sync_all().unwrap();
std::fs::rename(tmp, &self.cookie_path).unwrap();
}
}
const DEFAULT_HOST: Host<&'static str> = Host::Ipv4(Ipv4Addr::new(127, 0, 0, 1));
const DEFAULT_PORT: u16 = 5959;
#[derive(Debug, Clone)]
pub struct CliContext(Arc<CliContextSeed>);
impl CliContext {
/// BLOCKING
#[instrument(skip_all)]
pub fn init(matches: &ArgMatches) -> Result<Self, crate::Error> {
let local_config_path = local_config_path();
let base: CliContextConfig = load_config_from_paths(
matches
.values_of("config")
.into_iter()
.flatten()
.map(|p| Path::new(p))
.chain(local_config_path.as_deref().into_iter())
.chain(std::iter::once(Path::new(crate::util::config::CONFIG_PATH))),
)?;
let mut url = if let Some(host) = matches.value_of("host") {
host.parse()?
} else if let Some(host) = base.host {
host
} else {
"http://localhost".parse()?
};
let proxy = if let Some(proxy) = matches.value_of("proxy") {
Some(proxy.parse()?)
} else {
base.proxy
};
let cookie_path = base.cookie_path.unwrap_or_else(|| {
local_config_path
.as_deref()
.unwrap_or_else(|| Path::new(crate::util::config::CONFIG_PATH))
.parent()
.unwrap_or(Path::new("/"))
.join(".cookies.json")
});
let cookie_store = Arc::new(CookieStoreMutex::new({
let mut store = if cookie_path.exists() {
CookieStore::load_json(BufReader::new(File::open(&cookie_path)?))
.map_err(|e| eyre!("{}", e))
.with_kind(crate::ErrorKind::Deserialization)?
} else {
CookieStore::default()
};
if let Ok(local) = std::fs::read_to_string(LOCAL_AUTH_COOKIE_PATH) {
store
.insert_raw(
&RawCookie::new("local", local),
&"http://localhost".parse()?,
)
.with_kind(crate::ErrorKind::Network)?;
}
store
}));
Ok(CliContext(Arc::new(CliContextSeed {
base_url: url.clone(),
rpc_url: {
url.path_segments_mut()
.map_err(|_| eyre!("Url cannot be base"))
.with_kind(crate::ErrorKind::ParseUrl)?
.push("rpc")
.push("v1");
url
},
client: {
let mut builder = Client::builder().cookie_provider(cookie_store.clone());
if let Some(proxy) = proxy {
builder =
builder.proxy(Proxy::all(proxy).with_kind(crate::ErrorKind::ParseUrl)?)
}
builder.build().expect("cannot fail")
},
cookie_store,
cookie_path,
})))
}
}
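// With no overrides, `base_url` is "http://localhost" and `rpc_url` appends the
// "rpc/v1" path segments, so requests target http://localhost/rpc/v1 (on port 5959,
// the DEFAULT_PORT, unless the URL specifies another).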
impl AsRef<Jwk> for CliContext {
fn as_ref(&self) -> &Jwk {
&*CURRENT_SECRET
}
}
impl std::ops::Deref for CliContext {
type Target = CliContextSeed;
fn deref(&self) -> &Self::Target {
&*self.0
}
}
impl Context for CliContext {
fn protocol(&self) -> &str {
self.0.base_url.scheme()
}
fn host(&self) -> Host<&str> {
self.0.base_url.host().unwrap_or(DEFAULT_HOST)
}
fn port(&self) -> u16 {
self.0.base_url.port().unwrap_or(DEFAULT_PORT)
}
fn path(&self) -> &str {
self.0.rpc_url.path()
}
fn url(&self) -> Url {
self.0.rpc_url.clone()
}
fn client(&self) -> &Client {
&self.0.client
}
}
/// Regression test: a config with an empty `proxy` value used to break parsing even though it was previously allowed; ensure such a config still deserializes.
#[test]
fn test_cli_proxy_empty() {
serde_yaml::from_str::<CliContextConfig>(
"
bind_rpc:
",
)
.unwrap();
}

View File

@@ -0,0 +1,83 @@
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::Context;
use serde::Deserialize;
use tokio::sync::broadcast::Sender;
use tracing::instrument;
use crate::shutdown::Shutdown;
use crate::util::config::load_config_from_paths;
use crate::Error;
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct DiagnosticContextConfig {
pub datadir: Option<PathBuf>,
}
impl DiagnosticContextConfig {
#[instrument(skip_all)]
pub async fn load<P: AsRef<Path> + Send + 'static>(path: Option<P>) -> Result<Self, Error> {
tokio::task::spawn_blocking(move || {
load_config_from_paths(
path.as_ref()
.into_iter()
.map(|p| p.as_ref())
.chain(std::iter::once(Path::new(
crate::util::config::DEVICE_CONFIG_PATH,
)))
.chain(std::iter::once(Path::new(crate::util::config::CONFIG_PATH))),
)
})
.await
.unwrap()
}
pub fn datadir(&self) -> &Path {
self.datadir
.as_deref()
.unwrap_or_else(|| Path::new("/embassy-data"))
}
}
pub struct DiagnosticContextSeed {
pub datadir: PathBuf,
pub shutdown: Sender<Option<Shutdown>>,
pub error: Arc<RpcError>,
pub disk_guid: Option<Arc<String>>,
}
#[derive(Clone)]
pub struct DiagnosticContext(Arc<DiagnosticContextSeed>);
impl DiagnosticContext {
#[instrument(skip_all)]
pub async fn init<P: AsRef<Path> + Send + 'static>(
path: Option<P>,
disk_guid: Option<Arc<String>>,
error: Error,
) -> Result<Self, Error> {
tracing::error!("Error: {}: Starting diagnostic UI", error);
tracing::debug!("{:?}", error);
let cfg = DiagnosticContextConfig::load(path).await?;
let (shutdown, _) = tokio::sync::broadcast::channel(1);
Ok(Self(Arc::new(DiagnosticContextSeed {
datadir: cfg.datadir().to_owned(),
shutdown,
disk_guid,
error: Arc::new(error.into()),
})))
}
}
impl Context for DiagnosticContext {}
impl Deref for DiagnosticContext {
type Target = DiagnosticContextSeed;
fn deref(&self) -> &Self::Target {
&*self.0
}
}

View File

@@ -0,0 +1,58 @@
use std::ops::Deref;
use std::path::Path;
use std::sync::Arc;
use rpc_toolkit::Context;
use serde::Deserialize;
use tokio::sync::broadcast::Sender;
use tracing::instrument;
use crate::net::utils::find_eth_iface;
use crate::util::config::load_config_from_paths;
use crate::Error;
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct InstallContextConfig {}
impl InstallContextConfig {
#[instrument(skip_all)]
pub async fn load<P: AsRef<Path> + Send + 'static>(path: Option<P>) -> Result<Self, Error> {
tokio::task::spawn_blocking(move || {
load_config_from_paths(
path.as_ref()
.into_iter()
.map(|p| p.as_ref())
.chain(std::iter::once(Path::new(crate::util::config::CONFIG_PATH))),
)
})
.await
.unwrap()
}
}
pub struct InstallContextSeed {
pub ethernet_interface: String,
pub shutdown: Sender<()>,
}
#[derive(Clone)]
pub struct InstallContext(Arc<InstallContextSeed>);
impl InstallContext {
#[instrument(skip_all)]
pub async fn init<P: AsRef<Path> + Send + 'static>(path: Option<P>) -> Result<Self, Error> {
let _cfg = InstallContextConfig::load(path.as_ref().map(|p| p.as_ref().to_owned())).await?;
let (shutdown, _) = tokio::sync::broadcast::channel(1);
Ok(Self(Arc::new(InstallContextSeed {
ethernet_interface: find_eth_iface().await?,
shutdown,
})))
}
}
impl Context for InstallContext {}
impl Deref for InstallContext {
type Target = InstallContextSeed;
fn deref(&self) -> &Self::Target {
&*self.0
}
}

View File

@@ -0,0 +1,44 @@
pub mod cli;
pub mod diagnostic;
pub mod install;
pub mod rpc;
pub mod sdk;
pub mod setup;
pub use cli::CliContext;
pub use diagnostic::DiagnosticContext;
pub use install::InstallContext;
pub use rpc::RpcContext;
pub use sdk::SdkContext;
pub use setup::SetupContext;
impl From<CliContext> for () {
fn from(_: CliContext) -> Self {
()
}
}
impl From<DiagnosticContext> for () {
fn from(_: DiagnosticContext) -> Self {
()
}
}
impl From<RpcContext> for () {
fn from(_: RpcContext) -> Self {
()
}
}
impl From<SdkContext> for () {
fn from(_: SdkContext) -> Self {
()
}
}
impl From<SetupContext> for () {
fn from(_: SetupContext) -> Self {
()
}
}
impl From<InstallContext> for () {
fn from(_: InstallContext) -> Self {
()
}
}

View File

@@ -0,0 +1,466 @@
use std::collections::BTreeMap;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use helpers::to_tmp_path;
use josekit::jwk::Jwk;
use patch_db::json_ptr::JsonPointer;
use patch_db::PatchDb;
use reqwest::{Client, Proxy, Url};
use rpc_toolkit::Context;
use serde::Deserialize;
use sqlx::postgres::PgConnectOptions;
use sqlx::PgPool;
use tokio::sync::{broadcast, oneshot, Mutex, RwLock};
use tokio::time::Instant;
use tracing::instrument;
use super::setup::CURRENT_SECRET;
use crate::account::AccountInfo;
use crate::core::rpc_continuations::{RequestGuid, RestHandler, RpcContinuation};
use crate::db::model::{CurrentDependents, Database, PackageDataEntryMatchModelRef};
use crate::db::prelude::PatchDbExt;
use crate::dependencies::compute_dependency_config_errs;
use crate::disk::OsPartitionInfo;
use crate::init::init_postgres;
use crate::install::cleanup::{cleanup_failed, uninstall};
use crate::manager::ManagerMap;
use crate::middleware::auth::HashSessionToken;
use crate::net::net_controller::NetController;
use crate::net::ssl::{root_ca_start_time, SslManager};
use crate::net::wifi::WpaCli;
use crate::notifications::NotificationManager;
use crate::shutdown::Shutdown;
use crate::status::MainStatus;
use crate::system::get_mem_info;
use crate::util::config::load_config_from_paths;
use crate::util::lshw::{lshw, LshwDevice};
use crate::{Error, ErrorKind, ResultExt};
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct RpcContextConfig {
pub wifi_interface: Option<String>,
pub ethernet_interface: String,
pub os_partitions: OsPartitionInfo,
pub migration_batch_rows: Option<usize>,
pub migration_prefetch_rows: Option<usize>,
pub bind_rpc: Option<SocketAddr>,
pub tor_control: Option<SocketAddr>,
pub tor_socks: Option<SocketAddr>,
pub dns_bind: Option<Vec<SocketAddr>>,
pub revision_cache_size: Option<usize>,
pub datadir: Option<PathBuf>,
pub log_server: Option<Url>,
}
impl RpcContextConfig {
pub async fn load<P: AsRef<Path> + Send + 'static>(path: Option<P>) -> Result<Self, Error> {
tokio::task::spawn_blocking(move || {
load_config_from_paths(
path.as_ref()
.into_iter()
.map(|p| p.as_ref())
.chain(std::iter::once(Path::new(
crate::util::config::DEVICE_CONFIG_PATH,
)))
.chain(std::iter::once(Path::new(crate::util::config::CONFIG_PATH))),
)
})
.await
.unwrap()
}
pub fn datadir(&self) -> &Path {
self.datadir
.as_deref()
.unwrap_or_else(|| Path::new("/embassy-data"))
}
pub async fn db(&self, account: &AccountInfo) -> Result<PatchDb, Error> {
let db_path = self.datadir().join("main").join("embassy.db");
let db = PatchDb::open(&db_path)
.await
.with_ctx(|_| (crate::ErrorKind::Filesystem, db_path.display().to_string()))?;
if !db.exists(&<JsonPointer>::default()).await {
db.put(&<JsonPointer>::default(), &Database::init(account))
.await?;
}
Ok(db)
}
#[instrument(skip_all)]
pub async fn secret_store(&self) -> Result<PgPool, Error> {
init_postgres(self.datadir()).await?;
let secret_store =
PgPool::connect_with(PgConnectOptions::new().database("secrets").username("root"))
.await?;
sqlx::migrate!()
.run(&secret_store)
.await
.with_kind(crate::ErrorKind::Database)?;
Ok(secret_store)
}
}
pub struct RpcContextSeed {
is_closed: AtomicBool,
pub os_partitions: OsPartitionInfo,
pub wifi_interface: Option<String>,
pub ethernet_interface: String,
pub datadir: PathBuf,
pub disk_guid: Arc<String>,
pub db: PatchDb,
pub secret_store: PgPool,
pub account: RwLock<AccountInfo>,
pub net_controller: Arc<NetController>,
pub managers: ManagerMap,
pub metrics_cache: RwLock<Option<crate::system::Metrics>>,
pub shutdown: broadcast::Sender<Option<Shutdown>>,
pub tor_socks: SocketAddr,
pub notification_manager: NotificationManager,
pub open_authed_websockets: Mutex<BTreeMap<HashSessionToken, Vec<oneshot::Sender<()>>>>,
pub rpc_stream_continuations: Mutex<BTreeMap<RequestGuid, RpcContinuation>>,
pub wifi_manager: Option<Arc<RwLock<WpaCli>>>,
pub current_secret: Arc<Jwk>,
pub client: Client,
pub hardware: Hardware,
pub start_time: Instant,
}
pub struct Hardware {
pub devices: Vec<LshwDevice>,
pub ram: u64,
}
#[derive(Clone)]
pub struct RpcContext(Arc<RpcContextSeed>);
impl RpcContext {
#[instrument(skip_all)]
pub async fn init<P: AsRef<Path> + Send + Sync + 'static>(
cfg_path: Option<P>,
disk_guid: Arc<String>,
) -> Result<Self, Error> {
let base = RpcContextConfig::load(cfg_path).await?;
tracing::info!("Loaded Config");
let tor_proxy = base.tor_socks.unwrap_or(SocketAddr::V4(SocketAddrV4::new(
Ipv4Addr::new(127, 0, 0, 1),
9050,
)));
let (shutdown, _) = tokio::sync::broadcast::channel(1);
let secret_store = base.secret_store().await?;
tracing::info!("Opened Pg DB");
let account = AccountInfo::load(&secret_store).await?;
let db = base.db(&account).await?;
tracing::info!("Opened PatchDB");
let net_controller = Arc::new(
NetController::init(
base.tor_control
.unwrap_or(SocketAddr::from(([127, 0, 0, 1], 9051))),
tor_proxy,
base.dns_bind
.as_deref()
.unwrap_or(&[SocketAddr::from(([127, 0, 0, 1], 53))]),
SslManager::new(&account, root_ca_start_time().await?)?,
&account.hostname,
&account.key,
)
.await?,
);
tracing::info!("Initialized Net Controller");
let managers = ManagerMap::default();
let metrics_cache = RwLock::<Option<crate::system::Metrics>>::new(None);
let notification_manager = NotificationManager::new(secret_store.clone());
tracing::info!("Initialized Notification Manager");
let tor_proxy_url = format!("socks5h://{tor_proxy}");
let devices = lshw().await?;
let ram = get_mem_info().await?.total.0 as u64 * 1024 * 1024;
let seed = Arc::new(RpcContextSeed {
is_closed: AtomicBool::new(false),
datadir: base.datadir().to_path_buf(),
os_partitions: base.os_partitions,
wifi_interface: base.wifi_interface.clone(),
ethernet_interface: base.ethernet_interface,
disk_guid,
db,
secret_store,
account: RwLock::new(account),
net_controller,
managers,
metrics_cache,
shutdown,
tor_socks: tor_proxy,
notification_manager,
open_authed_websockets: Mutex::new(BTreeMap::new()),
rpc_stream_continuations: Mutex::new(BTreeMap::new()),
wifi_manager: base
.wifi_interface
.map(|i| Arc::new(RwLock::new(WpaCli::init(i)))),
current_secret: Arc::new(
Jwk::generate_ec_key(josekit::jwk::alg::ec::EcCurve::P256).map_err(|e| {
tracing::debug!("{:?}", e);
tracing::error!("Couldn't generate ec key");
Error::new(
color_eyre::eyre::eyre!("Couldn't generate ec key"),
crate::ErrorKind::Unknown,
)
})?,
),
client: Client::builder()
.proxy(Proxy::custom(move |url| {
if url.host_str().map_or(false, |h| h.ends_with(".onion")) {
Some(tor_proxy_url.clone())
} else {
None
}
}))
.build()
.with_kind(crate::ErrorKind::ParseUrl)?,
hardware: Hardware { devices, ram },
start_time: Instant::now(),
});
let res = Self(seed.clone());
res.cleanup_and_initialize().await?;
tracing::info!("Cleaned up transient states");
Ok(res)
}
#[instrument(skip_all)]
pub async fn shutdown(self) -> Result<(), Error> {
self.managers.empty().await?;
self.secret_store.close().await;
self.is_closed.store(true, Ordering::SeqCst);
tracing::info!("RPC Context is shutdown");
// TODO: shutdown http servers
Ok(())
}
#[instrument(skip(self))]
pub async fn cleanup_and_initialize(&self) -> Result<(), Error> {
self.db
.mutate(|f| {
let mut current_dependents = f
.as_package_data()
.keys()?
.into_iter()
.map(|k| (k.clone(), BTreeMap::new()))
.collect::<BTreeMap<_, _>>();
for (package_id, package) in f.as_package_data_mut().as_entries_mut()? {
for (k, v) in package
.as_installed_mut()
.into_iter()
.flat_map(|i| i.clone().into_current_dependencies().into_entries())
.flatten()
{
let mut entry: BTreeMap<_, _> =
current_dependents.remove(&k).unwrap_or_default();
entry.insert(package_id.clone(), v.de()?);
current_dependents.insert(k, entry);
}
}
for (package_id, current_dependents) in current_dependents {
if let Some(deps) = f
.as_package_data_mut()
.as_idx_mut(&package_id)
.and_then(|pde| pde.expect_as_installed_mut().ok())
.map(|i| i.as_installed_mut().as_current_dependents_mut())
{
deps.ser(&CurrentDependents(current_dependents))?;
} else if let Some(deps) = f
.as_package_data_mut()
.as_idx_mut(&package_id)
.and_then(|pde| pde.expect_as_removing_mut().ok())
.map(|i| i.as_removing_mut().as_current_dependents_mut())
{
deps.ser(&CurrentDependents(current_dependents))?;
}
}
Ok(())
})
.await?;
let peek = self.db.peek().await;
for (package_id, package) in peek.as_package_data().as_entries()?.into_iter() {
let action = match package.as_match() {
PackageDataEntryMatchModelRef::Installing(_)
| PackageDataEntryMatchModelRef::Restoring(_)
| PackageDataEntryMatchModelRef::Updating(_) => {
cleanup_failed(self, &package_id).await
}
PackageDataEntryMatchModelRef::Removing(_) => {
uninstall(
self,
self.secret_store.acquire().await?.as_mut(),
&package_id,
)
.await
}
PackageDataEntryMatchModelRef::Installed(m) => {
let version = m.as_manifest().as_version().clone().de()?;
let volumes = m.as_manifest().as_volumes().de()?;
for (volume_id, volume_info) in &*volumes {
let tmp_path = to_tmp_path(volume_info.path_for(
&self.datadir,
&package_id,
&version,
volume_id,
))
.with_kind(ErrorKind::Filesystem)?;
if tokio::fs::metadata(&tmp_path).await.is_ok() {
tokio::fs::remove_dir_all(&tmp_path).await?;
}
}
Ok(())
}
_ => continue,
};
if let Err(e) = action {
tracing::error!("Failed to clean up package {}: {}", package_id, e);
tracing::debug!("{:?}", e);
}
}
let peek = self
.db
.mutate(|v| {
for (_, pde) in v.as_package_data_mut().as_entries_mut()? {
let status = pde
.expect_as_installed_mut()?
.as_installed_mut()
.as_status_mut()
.as_main_mut();
let running = status.clone().de()?.running();
status.ser(&if running {
MainStatus::Starting
} else {
MainStatus::Stopped
})?;
}
Ok(v.clone())
})
.await?;
self.managers.init(self.clone(), peek.clone()).await?;
tracing::info!("Initialized Package Managers");
let mut all_dependency_config_errs = BTreeMap::new();
for (package_id, package) in peek.as_package_data().as_entries()?.into_iter() {
let package = package.clone();
if let Some(current_dependencies) = package
.as_installed()
.and_then(|x| x.as_current_dependencies().de().ok())
{
let manifest = package.as_manifest().de()?;
all_dependency_config_errs.insert(
package_id.clone(),
compute_dependency_config_errs(
self,
&peek,
&manifest,
&current_dependencies,
&Default::default(),
)
.await?,
);
}
}
self.db
.mutate(|v| {
for (package_id, errs) in all_dependency_config_errs {
if let Some(config_errors) = v
.as_package_data_mut()
.as_idx_mut(&package_id)
.and_then(|pde| pde.as_installed_mut())
.map(|i| i.as_status_mut().as_dependency_config_errors_mut())
{
config_errors.ser(&errs)?;
}
}
Ok(())
})
.await?;
Ok(())
}
#[instrument(skip_all)]
pub async fn clean_continuations(&self) {
let mut continuations = self.rpc_stream_continuations.lock().await;
let mut to_remove = Vec::new();
for (guid, cont) in &*continuations {
if cont.is_timed_out() {
to_remove.push(guid.clone());
}
}
for guid in to_remove {
continuations.remove(&guid);
}
}
#[instrument(skip_all)]
pub async fn add_continuation(&self, guid: RequestGuid, handler: RpcContinuation) {
self.clean_continuations().await;
self.rpc_stream_continuations
.lock()
.await
.insert(guid, handler);
}
pub async fn get_continuation_handler(&self, guid: &RequestGuid) -> Option<RestHandler> {
let mut continuations = self.rpc_stream_continuations.lock().await;
if let Some(cont) = continuations.remove(guid) {
cont.into_handler().await
} else {
None
}
}
pub async fn get_ws_continuation_handler(&self, guid: &RequestGuid) -> Option<RestHandler> {
let continuations = self.rpc_stream_continuations.lock().await;
if matches!(continuations.get(guid), Some(RpcContinuation::WebSocket(_))) {
drop(continuations);
self.get_continuation_handler(guid).await
} else {
None
}
}
pub async fn get_rest_continuation_handler(&self, guid: &RequestGuid) -> Option<RestHandler> {
let continuations = self.rpc_stream_continuations.lock().await;
if matches!(continuations.get(guid), Some(RpcContinuation::Rest(_))) {
drop(continuations);
self.get_continuation_handler(guid).await
} else {
None
}
}
}
impl AsRef<Jwk> for RpcContext {
fn as_ref(&self) -> &Jwk {
&CURRENT_SECRET
}
}
impl Context for RpcContext {}
impl Deref for RpcContext {
type Target = RpcContextSeed;
fn deref(&self) -> &Self::Target {
#[cfg(feature = "unstable")]
if self.0.is_closed.load(Ordering::SeqCst) {
panic!(
"RpcContext used after shutdown! {}",
tracing_error::SpanTrace::capture()
);
}
&self.0
}
}
impl Drop for RpcContext {
fn drop(&mut self) {
#[cfg(feature = "unstable")]
if self.0.is_closed.load(Ordering::SeqCst) {
tracing::info!(
"RpcContext dropped. {} left.",
Arc::strong_count(&self.0) - 1
);
}
}
}

View File

@@ -0,0 +1,76 @@
use std::path::{Path, PathBuf};
use std::sync::Arc;
use clap::ArgMatches;
use color_eyre::eyre::eyre;
use rpc_toolkit::Context;
use serde::Deserialize;
use tracing::instrument;
use crate::prelude::*;
use crate::util::config::{load_config_from_paths, local_config_path};
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct SdkContextConfig {
pub developer_key_path: Option<PathBuf>,
}
#[derive(Debug)]
pub struct SdkContextSeed {
pub developer_key_path: PathBuf,
}
#[derive(Debug, Clone)]
pub struct SdkContext(Arc<SdkContextSeed>);
impl SdkContext {
/// BLOCKING
#[instrument(skip_all)]
pub fn init(matches: &ArgMatches) -> Result<Self, crate::Error> {
let local_config_path = local_config_path();
let base: SdkContextConfig = load_config_from_paths(
matches
.values_of("config")
.into_iter()
.flatten()
.map(|p| Path::new(p))
.chain(local_config_path.as_deref().into_iter())
.chain(std::iter::once(Path::new(crate::util::config::CONFIG_PATH))),
)?;
Ok(SdkContext(Arc::new(SdkContextSeed {
developer_key_path: base.developer_key_path.unwrap_or_else(|| {
local_config_path
.as_deref()
.unwrap_or_else(|| Path::new(crate::util::config::CONFIG_PATH))
.parent()
.unwrap_or(Path::new("/"))
.join("developer.key.pem")
}),
})))
}
/// BLOCKING
#[instrument(skip_all)]
pub fn developer_key(&self) -> Result<ed25519_dalek::SigningKey, Error> {
if !self.developer_key_path.exists() {
return Err(Error::new(eyre!("Developer Key does not exist! Please run `start-sdk init` before running this command."), crate::ErrorKind::Uninitialized));
}
let pair = <ed25519::KeypairBytes as ed25519::pkcs8::DecodePrivateKey>::from_pkcs8_pem(
&std::fs::read_to_string(&self.developer_key_path)?,
)
.with_kind(crate::ErrorKind::Pem)?;
let secret = ed25519_dalek::SecretKey::try_from(&pair.secret_key[..]).map_err(|_| {
Error::new(
eyre!("pkcs8 key is of incorrect length"),
ErrorKind::OpenSsl,
)
})?;
Ok(secret.into())
}
}
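// Editor's note (illustrative, not part of this commit): the developer key read above is
// an ed25519 private key stored as PKCS#8 PEM at developer_key_path; per the error
// message, `start-sdk init` is the command expected to create it before any signing
// command is run.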
impl std::ops::Deref for SdkContext {
type Target = SdkContextSeed;
fn deref(&self) -> &Self::Target {
&*self.0
}
}
impl Context for SdkContext {}

View File

@@ -0,0 +1,149 @@
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use josekit::jwk::Jwk;
use patch_db::json_ptr::JsonPointer;
use patch_db::PatchDb;
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::Context;
use serde::{Deserialize, Serialize};
use sqlx::postgres::PgConnectOptions;
use sqlx::PgPool;
use tokio::sync::broadcast::Sender;
use tokio::sync::RwLock;
use tracing::instrument;
use crate::account::AccountInfo;
use crate::db::model::Database;
use crate::disk::OsPartitionInfo;
use crate::init::init_postgres;
use crate::setup::SetupStatus;
use crate::util::config::load_config_from_paths;
use crate::{Error, ResultExt};
lazy_static::lazy_static! {
pub static ref CURRENT_SECRET: Jwk = Jwk::generate_ec_key(josekit::jwk::alg::ec::EcCurve::P256).unwrap_or_else(|e| {
tracing::debug!("{:?}", e);
tracing::error!("Couldn't generate ec key");
panic!("Couldn't generate ec key")
});
}
#[derive(Clone, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct SetupResult {
pub tor_address: String,
pub lan_address: String,
pub root_ca: String,
}
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct SetupContextConfig {
pub os_partitions: OsPartitionInfo,
pub migration_batch_rows: Option<usize>,
pub migration_prefetch_rows: Option<usize>,
pub datadir: Option<PathBuf>,
#[serde(default)]
pub disable_encryption: bool,
}
impl SetupContextConfig {
#[instrument(skip_all)]
pub async fn load<P: AsRef<Path> + Send + 'static>(path: Option<P>) -> Result<Self, Error> {
tokio::task::spawn_blocking(move || {
load_config_from_paths(
path.as_ref()
.into_iter()
.map(|p| p.as_ref())
.chain(std::iter::once(Path::new(
crate::util::config::DEVICE_CONFIG_PATH,
)))
.chain(std::iter::once(Path::new(crate::util::config::CONFIG_PATH))),
)
})
.await
.unwrap()
}
pub fn datadir(&self) -> &Path {
self.datadir
.as_deref()
.unwrap_or_else(|| Path::new("/embassy-data"))
}
}
pub struct SetupContextSeed {
pub os_partitions: OsPartitionInfo,
pub config_path: Option<PathBuf>,
pub migration_batch_rows: usize,
pub migration_prefetch_rows: usize,
pub disable_encryption: bool,
pub shutdown: Sender<()>,
pub datadir: PathBuf,
pub selected_v2_drive: RwLock<Option<PathBuf>>,
pub cached_product_key: RwLock<Option<Arc<String>>>,
pub setup_status: RwLock<Option<Result<SetupStatus, RpcError>>>,
pub setup_result: RwLock<Option<(Arc<String>, SetupResult)>>,
}
impl AsRef<Jwk> for SetupContextSeed {
fn as_ref(&self) -> &Jwk {
&*CURRENT_SECRET
}
}
#[derive(Clone)]
pub struct SetupContext(Arc<SetupContextSeed>);
impl SetupContext {
#[instrument(skip_all)]
pub async fn init<P: AsRef<Path> + Send + 'static>(path: Option<P>) -> Result<Self, Error> {
let cfg = SetupContextConfig::load(path.as_ref().map(|p| p.as_ref().to_owned())).await?;
let (shutdown, _) = tokio::sync::broadcast::channel(1);
let datadir = cfg.datadir().to_owned();
Ok(Self(Arc::new(SetupContextSeed {
os_partitions: cfg.os_partitions,
config_path: path.as_ref().map(|p| p.as_ref().to_owned()),
migration_batch_rows: cfg.migration_batch_rows.unwrap_or(25000),
migration_prefetch_rows: cfg.migration_prefetch_rows.unwrap_or(100_000),
disable_encryption: cfg.disable_encryption,
shutdown,
datadir,
selected_v2_drive: RwLock::new(None),
cached_product_key: RwLock::new(None),
setup_status: RwLock::new(None),
setup_result: RwLock::new(None),
})))
}
#[instrument(skip_all)]
pub async fn db(&self, account: &AccountInfo) -> Result<PatchDb, Error> {
let db_path = self.datadir.join("main").join("embassy.db");
let db = PatchDb::open(&db_path)
.await
.with_ctx(|_| (crate::ErrorKind::Filesystem, db_path.display().to_string()))?;
if !db.exists(&<JsonPointer>::default()).await {
db.put(&<JsonPointer>::default(), &Database::init(account))
.await?;
}
Ok(db)
}
#[instrument(skip_all)]
pub async fn secret_store(&self) -> Result<PgPool, Error> {
init_postgres(&self.datadir).await?;
let secret_store =
PgPool::connect_with(PgConnectOptions::new().database("secrets").username("root"))
.await?;
sqlx::migrate!()
.run(&secret_store)
.await
.with_kind(crate::ErrorKind::Database)?;
Ok(secret_store)
}
}
impl Context for SetupContext {}
impl Deref for SetupContext {
type Target = SetupContextSeed;
fn deref(&self) -> &Self::Target {
&*self.0
}
}

View File

@@ -0,0 +1,92 @@
use color_eyre::eyre::eyre;
use rpc_toolkit::command;
use tracing::instrument;
use crate::context::RpcContext;
use crate::prelude::*;
use crate::s9pk::manifest::PackageId;
use crate::status::MainStatus;
use crate::util::display_none;
use crate::Error;
#[command(display(display_none), metadata(sync_db = true))]
#[instrument(skip_all)]
pub async fn start(#[context] ctx: RpcContext, #[arg] id: PackageId) -> Result<(), Error> {
let peek = ctx.db.peek().await;
let version = peek
.as_package_data()
.as_idx(&id)
.or_not_found(&id)?
.as_installed()
.or_not_found(&id)?
.as_manifest()
.as_version()
.de()?;
ctx.managers
.get(&(id, version))
.await
.ok_or_else(|| Error::new(eyre!("Manager not found"), crate::ErrorKind::InvalidRequest))?
.start()
.await;
Ok(())
}
#[command(display(display_none), metadata(sync_db = true))]
pub async fn stop(#[context] ctx: RpcContext, #[arg] id: PackageId) -> Result<MainStatus, Error> {
let peek = ctx.db.peek().await;
let version = peek
.as_package_data()
.as_idx(&id)
.or_not_found(&id)?
.as_installed()
.or_not_found(&id)?
.as_manifest()
.as_version()
.de()?;
let last_status = ctx
.db
.mutate(|v| {
v.as_package_data_mut()
.as_idx_mut(&id)
.and_then(|x| x.as_installed_mut())
.ok_or_else(|| Error::new(eyre!("{} is not installed", id), ErrorKind::NotFound))?
.as_status_mut()
.as_main_mut()
.replace(&MainStatus::Stopping)
})
.await?;
ctx.managers
.get(&(id, version))
.await
.ok_or_else(|| Error::new(eyre!("Manager not found"), crate::ErrorKind::InvalidRequest))?
.stop()
.await;
Ok(last_status)
}
#[command(display(display_none), metadata(sync_db = true))]
pub async fn restart(#[context] ctx: RpcContext, #[arg] id: PackageId) -> Result<(), Error> {
let peek = ctx.db.peek().await;
let version = peek
.as_package_data()
.as_idx(&id)
.or_not_found(&id)?
.expect_as_installed()?
.as_manifest()
.as_version()
.de()?;
ctx.managers
.get(&(id, version))
.await
.ok_or_else(|| Error::new(eyre!("Manager not found"), crate::ErrorKind::InvalidRequest))?
.restart()
.await;
Ok(())
}

View File

@@ -0,0 +1 @@
pub mod rpc_continuations;

View File

@@ -0,0 +1,116 @@
use std::sync::Arc;
use std::time::Duration;
use futures::future::BoxFuture;
use futures::FutureExt;
use helpers::TimedResource;
use hyper::upgrade::Upgraded;
use hyper::{Body, Error as HyperError, Request, Response};
use rand::RngCore;
use tokio::task::JoinError;
use tokio_tungstenite::WebSocketStream;
use crate::{Error, ResultExt};
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
pub struct RequestGuid<T: AsRef<str> = String>(Arc<T>);
impl RequestGuid {
pub fn new() -> Self {
let mut buf = [0; 40];
rand::thread_rng().fill_bytes(&mut buf);
RequestGuid(Arc::new(base32::encode(
base32::Alphabet::RFC4648 { padding: false },
&buf,
)))
}
pub fn from(r: &str) -> Option<RequestGuid> {
if r.len() != 64 {
return None;
}
for c in r.chars() {
if !(c.is_ascii_uppercase() || ('2'..='7').contains(&c)) {
return None;
}
}
Some(RequestGuid(Arc::new(r.to_owned())))
}
}
#[test]
fn parse_guid() {
let guid = RequestGuid::new();
let parsed = RequestGuid::from(&guid.to_string()).expect("freshly generated guid should round-trip");
assert_eq!(parsed.to_string(), guid.to_string());
}
impl<T: AsRef<str>> std::fmt::Display for RequestGuid<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
(&*self.0).as_ref().fmt(f)
}
}
pub type RestHandler = Box<
dyn FnOnce(Request<Body>) -> BoxFuture<'static, Result<Response<Body>, crate::Error>> + Send,
>;
pub type WebSocketHandler = Box<
dyn FnOnce(
BoxFuture<'static, Result<Result<WebSocketStream<Upgraded>, HyperError>, JoinError>>,
) -> BoxFuture<'static, Result<(), Error>>
+ Send,
>;
pub enum RpcContinuation {
Rest(TimedResource<RestHandler>),
WebSocket(TimedResource<WebSocketHandler>),
}
impl RpcContinuation {
pub fn rest(handler: RestHandler, timeout: Duration) -> Self {
RpcContinuation::Rest(TimedResource::new(handler, timeout))
}
pub fn ws(handler: WebSocketHandler, timeout: Duration) -> Self {
RpcContinuation::WebSocket(TimedResource::new(handler, timeout))
}
pub fn is_timed_out(&self) -> bool {
match self {
RpcContinuation::Rest(a) => a.is_timed_out(),
RpcContinuation::WebSocket(a) => a.is_timed_out(),
}
}
pub async fn into_handler(self) -> Option<RestHandler> {
match self {
RpcContinuation::Rest(handler) => handler.get().await,
RpcContinuation::WebSocket(handler) => {
if let Some(handler) = handler.get().await {
Some(Box::new(
|req: Request<Body>| -> BoxFuture<'static, Result<Response<Body>, Error>> {
async move {
let (parts, body) = req.into_parts();
let req = Request::from_parts(parts, body);
let (res, ws_fut) = hyper_ws_listener::create_ws(req)
.with_kind(crate::ErrorKind::Network)?;
if let Some(ws_fut) = ws_fut {
tokio::task::spawn(async move {
match handler(ws_fut.boxed()).await {
Ok(()) => (),
Err(e) => {
tracing::error!("WebSocket Closed: {}", e);
tracing::debug!("{:?}", e);
}
}
});
}
Ok(res)
}
.boxed()
},
))
} else {
None
}
}
}
}
}
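// Illustrative usage (editor's sketch, not part of this commit; timeout value assumed):
// a continuation is registered under a fresh guid and consumed at most once later:
//
//     let guid = RequestGuid::new();
//     ctx.add_continuation(
//         guid.clone(),
//         RpcContinuation::rest(handler, Duration::from_secs(30)),
//     )
//     .await;
//     // when the follow-up HTTP request arrives:
//     let rest_handler = ctx.get_rest_continuation_handler(&guid).await;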

370
core/startos/src/db/mod.rs Normal file
View File

@@ -0,0 +1,370 @@
pub mod model;
pub mod package;
pub mod prelude;
use std::future::Future;
use std::path::PathBuf;
use std::sync::Arc;
use futures::{FutureExt, SinkExt, StreamExt};
use patch_db::json_ptr::JsonPointer;
use patch_db::{Dump, Revision};
use rpc_toolkit::command;
use rpc_toolkit::hyper::upgrade::Upgraded;
use rpc_toolkit::hyper::{Body, Error as HyperError, Request, Response};
use rpc_toolkit::yajrc::RpcError;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::oneshot;
use tokio::task::JoinError;
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
use tokio_tungstenite::tungstenite::protocol::CloseFrame;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::WebSocketStream;
use tracing::instrument;
use crate::context::{CliContext, RpcContext};
use crate::middleware::auth::{HasValidSession, HashSessionToken};
use crate::prelude::*;
use crate::util::display_none;
use crate::util::serde::{display_serializable, IoFormat};
#[instrument(skip_all)]
async fn ws_handler<
WSFut: Future<Output = Result<Result<WebSocketStream<Upgraded>, HyperError>, JoinError>>,
>(
ctx: RpcContext,
session: Option<(HasValidSession, HashSessionToken)>,
ws_fut: WSFut,
) -> Result<(), Error> {
let (dump, sub) = ctx.db.dump_and_sub().await;
let mut stream = ws_fut
.await
.with_kind(ErrorKind::Network)?
.with_kind(ErrorKind::Unknown)?;
if let Some((session, token)) = session {
let kill = subscribe_to_session_kill(&ctx, token).await;
send_dump(session, &mut stream, dump).await?;
deal_with_messages(session, kill, sub, stream).await?;
} else {
stream
.close(Some(CloseFrame {
code: CloseCode::Error,
reason: "UNAUTHORIZED".into(),
}))
.await
.with_kind(ErrorKind::Network)?;
}
Ok(())
}
async fn subscribe_to_session_kill(
ctx: &RpcContext,
token: HashSessionToken,
) -> oneshot::Receiver<()> {
let (send, recv) = oneshot::channel();
let mut guard = ctx.open_authed_websockets.lock().await;
guard.entry(token).or_default().push(send);
recv
}
#[instrument(skip_all)]
async fn deal_with_messages(
_has_valid_authentication: HasValidSession,
mut kill: oneshot::Receiver<()>,
mut sub: patch_db::Subscriber,
mut stream: WebSocketStream<Upgraded>,
) -> Result<(), Error> {
let mut timer = tokio::time::interval(tokio::time::Duration::from_secs(5));
loop {
futures::select! {
_ = (&mut kill).fuse() => {
tracing::info!("Closing WebSocket: Reason: Session Terminated");
stream
.close(Some(CloseFrame {
code: CloseCode::Error,
reason: "UNAUTHORIZED".into(),
}))
.await
.with_kind(ErrorKind::Network)?;
return Ok(())
}
new_rev = sub.recv().fuse() => {
let rev = new_rev.expect("UNREACHABLE: patch-db is dropped");
stream
.send(Message::Text(serde_json::to_string(&rev).with_kind(ErrorKind::Serialization)?))
.await
.with_kind(ErrorKind::Network)?;
}
message = stream.next().fuse() => {
let message = message.transpose().with_kind(ErrorKind::Network)?;
match message {
None => {
tracing::info!("Closing WebSocket: Stream Finished");
return Ok(())
}
_ => (),
}
}
// Periodic ping to keep the UI's WebSocket connection alive.
_ = timer.tick().fuse() => {
stream
.send(Message::Ping(vec![]))
.await
.with_kind(crate::ErrorKind::Network)?;
}
}
}
}
async fn send_dump(
_has_valid_authentication: HasValidSession,
stream: &mut WebSocketStream<Upgraded>,
dump: Dump,
) -> Result<(), Error> {
stream
.send(Message::Text(
serde_json::to_string(&dump).with_kind(ErrorKind::Serialization)?,
))
.await
.with_kind(ErrorKind::Network)?;
Ok(())
}
pub async fn subscribe(ctx: RpcContext, req: Request<Body>) -> Result<Response<Body>, Error> {
let (parts, body) = req.into_parts();
let session = match async {
let token = HashSessionToken::from_request_parts(&parts)?;
let session = HasValidSession::from_request_parts(&parts, &ctx).await?;
Ok::<_, Error>((session, token))
}
.await
{
Ok(a) => Some(a),
Err(e) => {
if e.kind != ErrorKind::Authorization {
tracing::error!("Error Authenticating Websocket: {}", e);
tracing::debug!("{:?}", e);
}
None
}
};
let req = Request::from_parts(parts, body);
let (res, ws_fut) = hyper_ws_listener::create_ws(req).with_kind(ErrorKind::Network)?;
if let Some(ws_fut) = ws_fut {
tokio::task::spawn(async move {
match ws_handler(ctx, session, ws_fut).await {
Ok(()) => (),
Err(e) => {
tracing::error!("WebSocket Closed: {}", e);
tracing::debug!("{:?}", e);
}
}
});
}
Ok(res)
}
#[command(subcommands(dump, put, apply))]
pub fn db() -> Result<(), RpcError> {
Ok(())
}
#[derive(Deserialize, Serialize)]
#[serde(untagged)]
pub enum RevisionsRes {
Revisions(Vec<Arc<Revision>>),
Dump(Dump),
}
#[instrument(skip_all)]
async fn cli_dump(
ctx: CliContext,
_format: Option<IoFormat>,
path: Option<PathBuf>,
) -> Result<Dump, RpcError> {
let dump = if let Some(path) = path {
PatchDb::open(path).await?.dump().await
} else {
rpc_toolkit::command_helpers::call_remote(
ctx,
"db.dump",
serde_json::json!({}),
std::marker::PhantomData::<Dump>,
)
.await?
.result?
};
Ok(dump)
}
#[command(
custom_cli(cli_dump(async, context(CliContext))),
display(display_serializable)
)]
pub async fn dump(
#[context] ctx: RpcContext,
#[allow(unused_variables)]
#[arg(long = "format")]
format: Option<IoFormat>,
#[allow(unused_variables)]
#[arg]
path: Option<PathBuf>,
) -> Result<Dump, Error> {
Ok(ctx.db.dump().await)
}
fn apply_expr(input: jaq_core::Val, expr: &str) -> Result<jaq_core::Val, Error> {
let (expr, errs) = jaq_core::parse::parse(expr, jaq_core::parse::main());
let Some(expr) = expr else {
return Err(Error::new(
eyre!("Failed to parse expression: {:?}", errs),
crate::ErrorKind::InvalidRequest,
));
};
let mut errs = Vec::new();
let mut defs = jaq_core::Definitions::core();
for def in jaq_std::std() {
defs.insert(def, &mut errs);
}
let filter = defs.finish(expr, Vec::new(), &mut errs);
if !errs.is_empty() {
return Err(Error::new(
eyre!("Failed to compile expression: {:?}", errs),
crate::ErrorKind::InvalidRequest,
));
};
let inputs = jaq_core::RcIter::new(std::iter::empty());
let mut res_iter = filter.run(jaq_core::Ctx::new([], &inputs), input);
let Some(res) = res_iter
.next()
.transpose()
.map_err(|e| eyre!("{e}"))
.with_kind(crate::ErrorKind::Deserialization)?
else {
return Err(Error::new(
eyre!("expr returned no results"),
crate::ErrorKind::InvalidRequest,
));
};
if res_iter.next().is_some() {
return Err(Error::new(
eyre!("expr returned too many results"),
crate::ErrorKind::InvalidRequest,
));
}
Ok(res)
}
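// Illustrative note (editor's sketch): apply_expr evaluates a jq-style program (via jaq)
// against the full database value and requires exactly one result; e.g. a hypothetical
// `.ui.theme = "dark"` sent to `db.apply` would rewrite that UI field in place.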
#[instrument(skip_all)]
async fn cli_apply(ctx: CliContext, expr: String, path: Option<PathBuf>) -> Result<(), RpcError> {
if let Some(path) = path {
PatchDb::open(path)
.await?
.mutate(|db| {
let res = apply_expr(
serde_json::to_value(patch_db::Value::from(db.clone()))
.with_kind(ErrorKind::Deserialization)?
.into(),
&expr,
)?;
db.ser(
&serde_json::from_value::<model::Database>(res.clone().into()).with_ctx(
|_| {
(
crate::ErrorKind::Deserialization,
"result does not match database model",
)
},
)?,
)
})
.await?;
} else {
rpc_toolkit::command_helpers::call_remote(
ctx,
"db.apply",
serde_json::json!({ "expr": expr }),
std::marker::PhantomData::<()>,
)
.await?
.result?;
}
Ok(())
}
#[command(
custom_cli(cli_apply(async, context(CliContext))),
display(display_none)
)]
pub async fn apply(
#[context] ctx: RpcContext,
#[arg] expr: String,
#[allow(unused_variables)]
#[arg]
path: Option<PathBuf>,
) -> Result<(), Error> {
ctx.db
.mutate(|db| {
let res = apply_expr(
serde_json::to_value(patch_db::Value::from(db.clone()))
.with_kind(ErrorKind::Deserialization)?
.into(),
&expr,
)?;
db.ser(
&serde_json::from_value::<model::Database>(res.clone().into()).with_ctx(|_| {
(
crate::ErrorKind::Deserialization,
"result does not match database model",
)
})?,
)
})
.await
}
#[command(subcommands(ui))]
pub fn put() -> Result<(), RpcError> {
Ok(())
}
#[command(display(display_serializable))]
#[instrument(skip_all)]
pub async fn ui(
#[context] ctx: RpcContext,
#[arg] pointer: JsonPointer,
#[arg] value: Value,
#[allow(unused_variables)]
#[arg(long = "format")]
format: Option<IoFormat>,
) -> Result<(), Error> {
let ptr = "/ui"
.parse::<JsonPointer>()
.with_kind(ErrorKind::Database)?
+ &pointer;
ctx.db.put(&ptr, &value).await?;
Ok(())
}

View File

@@ -0,0 +1,527 @@
use std::collections::{BTreeMap, BTreeSet};
use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::Arc;
use chrono::{DateTime, Utc};
use emver::VersionRange;
use imbl_value::InternedString;
use ipnet::{Ipv4Net, Ipv6Net};
use isocountry::CountryCode;
use itertools::Itertools;
use models::{DataUrl, HealthCheckId, InterfaceId};
use openssl::hash::MessageDigest;
use patch_db::{HasModel, Value};
use reqwest::Url;
use serde::{Deserialize, Serialize};
use ssh_key::public::Ed25519PublicKey;
use crate::account::AccountInfo;
use crate::config::spec::PackagePointerSpec;
use crate::install::progress::InstallProgress;
use crate::net::utils::{get_iface_ipv4_addr, get_iface_ipv6_addr};
use crate::prelude::*;
use crate::s9pk::manifest::{Manifest, PackageId};
use crate::status::Status;
use crate::util::Version;
use crate::version::{Current, VersionT};
use crate::{ARCH, PLATFORM};
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
// #[macro_debug]
pub struct Database {
pub server_info: ServerInfo,
pub package_data: AllPackageData,
pub ui: Value,
}
impl Database {
pub fn init(account: &AccountInfo) -> Self {
let lan_address = account.hostname.lan_address().parse().unwrap();
Database {
server_info: ServerInfo {
arch: get_arch(),
platform: get_platform(),
id: account.server_id.clone(),
version: Current::new().semver().into(),
hostname: account.hostname.no_dot_host_name(),
last_backup: None,
last_wifi_region: None,
eos_version_compat: Current::new().compat().clone(),
lan_address,
tor_address: format!("https://{}", account.key.tor_address())
.parse()
.unwrap(),
ip_info: BTreeMap::new(),
status_info: ServerStatus {
backup_progress: None,
updated: false,
update_progress: None,
shutting_down: false,
restarting: false,
},
wifi: WifiInfo {
ssids: Vec::new(),
connected: None,
selected: None,
},
unread_notification_count: 0,
connection_addresses: ConnectionAddresses {
tor: Vec::new(),
clearnet: Vec::new(),
},
password_hash: account.password.clone(),
pubkey: ssh_key::PublicKey::from(Ed25519PublicKey::from(&account.key.ssh_key()))
.to_openssh()
.unwrap(),
ca_fingerprint: account
.root_ca_cert
.digest(MessageDigest::sha256())
.unwrap()
.iter()
.map(|x| format!("{x:02X}"))
.join(":"),
ntp_synced: false,
zram: true,
},
package_data: AllPackageData::default(),
ui: serde_json::from_str(include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/../../web/patchdb-ui-seed.json"
)))
.unwrap(),
}
}
}
pub type DatabaseModel = Model<Database>;
fn get_arch() -> InternedString {
(*ARCH).into()
}
fn get_platform() -> InternedString {
(&*PLATFORM).into()
}
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct ServerInfo {
#[serde(default = "get_arch")]
pub arch: InternedString,
#[serde(default = "get_platform")]
pub platform: InternedString,
pub id: String,
pub hostname: String,
pub version: Version,
pub last_backup: Option<DateTime<Utc>>,
/// Used by the wifi module to determine the regulatory region to set the system to
pub last_wifi_region: Option<CountryCode>,
pub eos_version_compat: VersionRange,
pub lan_address: Url,
pub tor_address: Url,
pub ip_info: BTreeMap<String, IpInfo>,
#[serde(default)]
pub status_info: ServerStatus,
pub wifi: WifiInfo,
pub unread_notification_count: u64,
pub connection_addresses: ConnectionAddresses,
pub password_hash: String,
pub pubkey: String,
pub ca_fingerprint: String,
#[serde(default)]
pub ntp_synced: bool,
#[serde(default)]
pub zram: bool,
}
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct IpInfo {
pub ipv4_range: Option<Ipv4Net>,
pub ipv4: Option<Ipv4Addr>,
pub ipv6_range: Option<Ipv6Net>,
pub ipv6: Option<Ipv6Addr>,
}
impl IpInfo {
pub async fn for_interface(iface: &str) -> Result<Self, Error> {
let (ipv4, ipv4_range) = get_iface_ipv4_addr(iface).await?.unzip();
let (ipv6, ipv6_range) = get_iface_ipv6_addr(iface).await?.unzip();
Ok(Self {
ipv4_range,
ipv4,
ipv6_range,
ipv6,
})
}
}
#[derive(Debug, Default, Deserialize, Serialize, HasModel)]
#[model = "Model<Self>"]
pub struct BackupProgress {
pub complete: bool,
}
#[derive(Debug, Default, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct ServerStatus {
pub backup_progress: Option<BTreeMap<PackageId, BackupProgress>>,
pub updated: bool,
pub update_progress: Option<UpdateProgress>,
#[serde(default)]
pub shutting_down: bool,
#[serde(default)]
pub restarting: bool,
}
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct UpdateProgress {
pub size: Option<u64>,
pub downloaded: u64,
}
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct WifiInfo {
pub ssids: Vec<String>,
pub selected: Option<String>,
pub connected: Option<String>,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct ServerSpecs {
pub cpu: String,
pub disk: String,
pub memory: String,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct ConnectionAddresses {
pub tor: Vec<String>,
pub clearnet: Vec<String>,
}
#[derive(Debug, Default, Deserialize, Serialize)]
pub struct AllPackageData(pub BTreeMap<PackageId, PackageDataEntry>);
impl Map for AllPackageData {
type Key = PackageId;
type Value = PackageDataEntry;
}
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct StaticFiles {
license: String,
instructions: String,
icon: String,
}
impl StaticFiles {
pub fn local(id: &PackageId, version: &Version, icon_type: &str) -> Self {
StaticFiles {
license: format!("/public/package-data/{}/{}/LICENSE.md", id, version),
instructions: format!("/public/package-data/{}/{}/INSTRUCTIONS.md", id, version),
icon: format!("/public/package-data/{}/{}/icon.{}", id, version, icon_type),
}
}
}
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct PackageDataEntryInstalling {
pub static_files: StaticFiles,
pub manifest: Manifest,
pub install_progress: Arc<InstallProgress>,
}
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct PackageDataEntryUpdating {
pub static_files: StaticFiles,
pub manifest: Manifest,
pub installed: InstalledPackageInfo,
pub install_progress: Arc<InstallProgress>,
}
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct PackageDataEntryRestoring {
pub static_files: StaticFiles,
pub manifest: Manifest,
pub install_progress: Arc<InstallProgress>,
}
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct PackageDataEntryRemoving {
pub static_files: StaticFiles,
pub manifest: Manifest,
pub removing: InstalledPackageInfo,
}
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct PackageDataEntryInstalled {
pub static_files: StaticFiles,
pub manifest: Manifest,
pub installed: InstalledPackageInfo,
}
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[serde(tag = "state")]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
// #[macro_debug]
pub enum PackageDataEntry {
Installing(PackageDataEntryInstalling),
Updating(PackageDataEntryUpdating),
Restoring(PackageDataEntryRestoring),
Removing(PackageDataEntryRemoving),
Installed(PackageDataEntryInstalled),
}
impl Model<PackageDataEntry> {
pub fn expect_into_installed(self) -> Result<Model<PackageDataEntryInstalled>, Error> {
if let PackageDataEntryMatchModel::Installed(a) = self.into_match() {
Ok(a)
} else {
Err(Error::new(
eyre!("package is not in installed state"),
ErrorKind::InvalidRequest,
))
}
}
pub fn expect_as_installed(&self) -> Result<&Model<PackageDataEntryInstalled>, Error> {
if let PackageDataEntryMatchModelRef::Installed(a) = self.as_match() {
Ok(a)
} else {
Err(Error::new(
eyre!("package is not in installed state"),
ErrorKind::InvalidRequest,
))
}
}
pub fn expect_as_installed_mut(
&mut self,
) -> Result<&mut Model<PackageDataEntryInstalled>, Error> {
if let PackageDataEntryMatchModelMut::Installed(a) = self.as_match_mut() {
Ok(a)
} else {
Err(Error::new(
eyre!("package is not in installed state"),
ErrorKind::InvalidRequest,
))
}
}
pub fn expect_into_removing(self) -> Result<Model<PackageDataEntryRemoving>, Error> {
if let PackageDataEntryMatchModel::Removing(a) = self.into_match() {
Ok(a)
} else {
Err(Error::new(
eyre!("package is not in removing state"),
ErrorKind::InvalidRequest,
))
}
}
pub fn expect_as_removing(&self) -> Result<&Model<PackageDataEntryRemoving>, Error> {
if let PackageDataEntryMatchModelRef::Removing(a) = self.as_match() {
Ok(a)
} else {
Err(Error::new(
eyre!("package is not in removing state"),
ErrorKind::InvalidRequest,
))
}
}
pub fn expect_as_removing_mut(
&mut self,
) -> Result<&mut Model<PackageDataEntryRemoving>, Error> {
if let PackageDataEntryMatchModelMut::Removing(a) = self.as_match_mut() {
Ok(a)
} else {
Err(Error::new(
eyre!("package is not in removing state"),
ErrorKind::InvalidRequest,
))
}
}
pub fn expect_as_installing_mut(
&mut self,
) -> Result<&mut Model<PackageDataEntryInstalling>, Error> {
if let PackageDataEntryMatchModelMut::Installing(a) = self.as_match_mut() {
Ok(a)
} else {
Err(Error::new(
eyre!("package is not in installing state"),
ErrorKind::InvalidRequest,
))
}
}
pub fn into_manifest(self) -> Model<Manifest> {
match self.into_match() {
PackageDataEntryMatchModel::Installing(a) => a.into_manifest(),
PackageDataEntryMatchModel::Updating(a) => a.into_installed().into_manifest(),
PackageDataEntryMatchModel::Restoring(a) => a.into_manifest(),
PackageDataEntryMatchModel::Removing(a) => a.into_manifest(),
PackageDataEntryMatchModel::Installed(a) => a.into_manifest(),
PackageDataEntryMatchModel::Error(_) => Model::from(Value::Null),
}
}
pub fn as_manifest(&self) -> &Model<Manifest> {
match self.as_match() {
PackageDataEntryMatchModelRef::Installing(a) => a.as_manifest(),
PackageDataEntryMatchModelRef::Updating(a) => a.as_installed().as_manifest(),
PackageDataEntryMatchModelRef::Restoring(a) => a.as_manifest(),
PackageDataEntryMatchModelRef::Removing(a) => a.as_manifest(),
PackageDataEntryMatchModelRef::Installed(a) => a.as_manifest(),
PackageDataEntryMatchModelRef::Error(_) => (&Value::Null).into(),
}
}
pub fn into_installed(self) -> Option<Model<InstalledPackageInfo>> {
match self.into_match() {
PackageDataEntryMatchModel::Installing(_) => None,
PackageDataEntryMatchModel::Updating(a) => Some(a.into_installed()),
PackageDataEntryMatchModel::Restoring(_) => None,
PackageDataEntryMatchModel::Removing(_) => None,
PackageDataEntryMatchModel::Installed(a) => Some(a.into_installed()),
PackageDataEntryMatchModel::Error(_) => None,
}
}
pub fn as_installed(&self) -> Option<&Model<InstalledPackageInfo>> {
match self.as_match() {
PackageDataEntryMatchModelRef::Installing(_) => None,
PackageDataEntryMatchModelRef::Updating(a) => Some(a.as_installed()),
PackageDataEntryMatchModelRef::Restoring(_) => None,
PackageDataEntryMatchModelRef::Removing(_) => None,
PackageDataEntryMatchModelRef::Installed(a) => Some(a.as_installed()),
PackageDataEntryMatchModelRef::Error(_) => None,
}
}
pub fn as_installed_mut(&mut self) -> Option<&mut Model<InstalledPackageInfo>> {
match self.as_match_mut() {
PackageDataEntryMatchModelMut::Installing(_) => None,
PackageDataEntryMatchModelMut::Updating(a) => Some(a.as_installed_mut()),
PackageDataEntryMatchModelMut::Restoring(_) => None,
PackageDataEntryMatchModelMut::Removing(_) => None,
PackageDataEntryMatchModelMut::Installed(a) => Some(a.as_installed_mut()),
PackageDataEntryMatchModelMut::Error(_) => None,
}
}
pub fn as_install_progress(&self) -> Option<&Model<Arc<InstallProgress>>> {
match self.as_match() {
PackageDataEntryMatchModelRef::Installing(a) => Some(a.as_install_progress()),
PackageDataEntryMatchModelRef::Updating(a) => Some(a.as_install_progress()),
PackageDataEntryMatchModelRef::Restoring(a) => Some(a.as_install_progress()),
PackageDataEntryMatchModelRef::Removing(_) => None,
PackageDataEntryMatchModelRef::Installed(_) => None,
PackageDataEntryMatchModelRef::Error(_) => None,
}
}
pub fn as_install_progress_mut(&mut self) -> Option<&mut Model<Arc<InstallProgress>>> {
match self.as_match_mut() {
PackageDataEntryMatchModelMut::Installing(a) => Some(a.as_install_progress_mut()),
PackageDataEntryMatchModelMut::Updating(a) => Some(a.as_install_progress_mut()),
PackageDataEntryMatchModelMut::Restoring(a) => Some(a.as_install_progress_mut()),
PackageDataEntryMatchModelMut::Removing(_) => None,
PackageDataEntryMatchModelMut::Installed(_) => None,
PackageDataEntryMatchModelMut::Error(_) => None,
}
}
}
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct InstalledPackageInfo {
pub status: Status,
pub marketplace_url: Option<Url>,
#[serde(default)]
#[serde(with = "crate::util::serde::ed25519_pubkey")]
pub developer_key: ed25519_dalek::VerifyingKey,
pub manifest: Manifest,
pub last_backup: Option<DateTime<Utc>>,
pub dependency_info: BTreeMap<PackageId, StaticDependencyInfo>,
pub current_dependents: CurrentDependents,
pub current_dependencies: CurrentDependencies,
pub interface_addresses: InterfaceAddressMap,
}
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct CurrentDependents(pub BTreeMap<PackageId, CurrentDependencyInfo>);
impl CurrentDependents {
pub fn map(
mut self,
transform: impl Fn(
BTreeMap<PackageId, CurrentDependencyInfo>,
) -> BTreeMap<PackageId, CurrentDependencyInfo>,
) -> Self {
self.0 = transform(self.0);
self
}
}
impl Map for CurrentDependents {
type Key = PackageId;
type Value = CurrentDependencyInfo;
}
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct CurrentDependencies(pub BTreeMap<PackageId, CurrentDependencyInfo>);
impl CurrentDependencies {
pub fn map(
mut self,
transform: impl Fn(
BTreeMap<PackageId, CurrentDependencyInfo>,
) -> BTreeMap<PackageId, CurrentDependencyInfo>,
) -> Self {
self.0 = transform(self.0);
self
}
}
impl Map for CurrentDependencies {
type Key = PackageId;
type Value = CurrentDependencyInfo;
}
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct StaticDependencyInfo {
pub title: String,
pub icon: DataUrl<'static>,
}
#[derive(Clone, Debug, Default, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct CurrentDependencyInfo {
#[serde(default)]
pub pointers: BTreeSet<PackagePointerSpec>,
pub health_checks: BTreeSet<HealthCheckId>,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct InterfaceAddressMap(pub BTreeMap<InterfaceId, InterfaceAddresses>);
impl Map for InterfaceAddressMap {
type Key = InterfaceId;
type Value = InterfaceAddresses;
}
#[derive(Debug, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct InterfaceAddresses {
pub tor_address: Option<String>,
pub lan_address: Option<String>,
}

View File

@@ -0,0 +1,22 @@
use models::Version;
use crate::prelude::*;
use crate::s9pk::manifest::PackageId;
pub fn get_packages(db: Peeked) -> Result<Vec<(PackageId, Version)>, Error> {
Ok(db
.as_package_data()
.keys()?
.into_iter()
.flat_map(|package_id| {
let version = db
.as_package_data()
.as_idx(&package_id)?
.as_manifest()
.as_version()
.de()
.ok()?;
Some((package_id, version))
})
.collect())
}

View File

@@ -0,0 +1,382 @@
use std::collections::BTreeMap;
use std::marker::PhantomData;
use std::panic::UnwindSafe;
use patch_db::value::InternedString;
pub use patch_db::{HasModel, PatchDb, Value};
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::db::model::DatabaseModel;
use crate::prelude::*;
pub type Peeked = Model<super::model::Database>;
pub fn to_value<T>(value: &T) -> Result<Value, Error>
where
T: Serialize,
{
patch_db::value::to_value(value).with_kind(ErrorKind::Serialization)
}
pub fn from_value<T>(value: Value) -> Result<T, Error>
where
T: DeserializeOwned,
{
patch_db::value::from_value(value).with_kind(ErrorKind::Deserialization)
}
#[async_trait::async_trait]
pub trait PatchDbExt {
async fn peek(&self) -> DatabaseModel;
async fn mutate<U: UnwindSafe + Send>(
&self,
f: impl FnOnce(&mut DatabaseModel) -> Result<U, Error> + UnwindSafe + Send,
) -> Result<U, Error>;
async fn map_mutate(
&self,
f: impl FnOnce(DatabaseModel) -> Result<DatabaseModel, Error> + UnwindSafe + Send,
) -> Result<DatabaseModel, Error>;
}
#[async_trait::async_trait]
impl PatchDbExt for PatchDb {
async fn peek(&self) -> DatabaseModel {
DatabaseModel::from(self.dump().await.value)
}
async fn mutate<U: UnwindSafe + Send>(
&self,
f: impl FnOnce(&mut DatabaseModel) -> Result<U, Error> + UnwindSafe + Send,
) -> Result<U, Error> {
Ok(self
.apply_function(|mut v| {
let model = <&mut DatabaseModel>::from(&mut v);
let res = f(model)?;
Ok::<_, Error>((v, res))
})
.await?
.1)
}
async fn map_mutate(
&self,
f: impl FnOnce(DatabaseModel) -> Result<DatabaseModel, Error> + UnwindSafe + Send,
) -> Result<DatabaseModel, Error> {
Ok(DatabaseModel::from(
self.apply_function(|v| f(DatabaseModel::from(v)).map(|a| (a.into(), ())))
.await?
.0,
))
}
}
/// A typed view over a patch-db `Value`; `#[repr(transparent)]` makes `&mut Model<T>` and `&mut Value` interchangeable (see the `From` conversions below).
#[repr(transparent)]
#[derive(Debug)]
pub struct Model<T> {
value: Value,
phantom: PhantomData<T>,
}
impl<T: DeserializeOwned> Model<T> {
pub fn de(&self) -> Result<T, Error> {
from_value(self.value.clone())
}
}
impl<T: Serialize> Model<T> {
pub fn new(value: &T) -> Result<Self, Error> {
Ok(Self::from(to_value(value)?))
}
pub fn ser(&mut self, value: &T) -> Result<(), Error> {
self.value = to_value(value)?;
Ok(())
}
}
impl<T: Serialize + DeserializeOwned> Model<T> {
pub fn replace(&mut self, value: &T) -> Result<T, Error> {
let orig = self.de()?;
self.ser(value)?;
Ok(orig)
}
}
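// Illustrative usage (editor's sketch, assuming the as_*/as_*_mut accessors generated by
// HasModel, as used elsewhere in this diff; `installed` is a Model<InstalledPackageInfo>):
//
//     let main: MainStatus = installed.as_status().as_main().de()?;        // typed read
//     installed.as_status_mut().as_main_mut().ser(&MainStatus::Stopped)?;  // typed write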
impl<T> Clone for Model<T> {
fn clone(&self) -> Self {
Self {
value: self.value.clone(),
phantom: PhantomData,
}
}
}
impl<T> From<Value> for Model<T> {
fn from(value: Value) -> Self {
Self {
value,
phantom: PhantomData,
}
}
}
impl<T> From<Model<T>> for Value {
fn from(value: Model<T>) -> Self {
value.value
}
}
impl<'a, T> From<&'a Value> for &'a Model<T> {
fn from(value: &'a Value) -> Self {
unsafe { std::mem::transmute(value) }
}
}
impl<'a, T> From<&'a Model<T>> for &'a Value {
fn from(value: &'a Model<T>) -> Self {
unsafe { std::mem::transmute(value) }
}
}
impl<'a, T> From<&'a mut Value> for &mut Model<T> {
fn from(value: &'a mut Value) -> Self {
unsafe { std::mem::transmute(value) }
}
}
impl<'a, T> From<&'a mut Model<T>> for &mut Value {
fn from(value: &'a mut Model<T>) -> Self {
unsafe { std::mem::transmute(value) }
}
}
impl<T> patch_db::Model<T> for Model<T> {
type Model<U> = Model<U>;
}
impl<T> Model<Option<T>> {
pub fn transpose(self) -> Option<Model<T>> {
use patch_db::ModelExt;
if self.value.is_null() {
None
} else {
Some(self.transmute(|a| a))
}
}
pub fn transpose_ref(&self) -> Option<&Model<T>> {
use patch_db::ModelExt;
if self.value.is_null() {
None
} else {
Some(self.transmute_ref(|a| a))
}
}
pub fn transpose_mut(&mut self) -> Option<&mut Model<T>> {
use patch_db::ModelExt;
if self.value.is_null() {
None
} else {
Some(self.transmute_mut(|a| a))
}
}
pub fn from_option(opt: Option<Model<T>>) -> Self {
use patch_db::ModelExt;
match opt {
Some(a) => a.transmute(|a| a),
None => Self::from_value(Value::Null),
}
}
}
pub trait Map: DeserializeOwned + Serialize {
type Key;
type Value;
}
impl<A, B> Map for BTreeMap<A, B>
where
A: serde::Serialize + serde::de::DeserializeOwned + Ord,
B: serde::Serialize + serde::de::DeserializeOwned,
{
type Key = A;
type Value = B;
}
impl<T: Map> Model<T>
where
T::Key: AsRef<str>,
T::Value: Serialize,
{
pub fn insert(&mut self, key: &T::Key, value: &T::Value) -> Result<(), Error> {
use serde::ser::Error;
let v = patch_db::value::to_value(value)?;
match &mut self.value {
Value::Object(o) => {
o.insert(InternedString::intern(key.as_ref()), v);
Ok(())
}
v => Err(patch_db::value::Error {
source: patch_db::value::ErrorSource::custom(format!("expected object found {v}")),
kind: patch_db::value::ErrorKind::Serialization,
}
.into()),
}
}
pub fn insert_model(&mut self, key: &T::Key, value: Model<T::Value>) -> Result<(), Error> {
use patch_db::ModelExt;
use serde::ser::Error;
let v = value.into_value();
match &mut self.value {
Value::Object(o) => {
o.insert(InternedString::intern(key.as_ref()), v);
Ok(())
}
v => Err(patch_db::value::Error {
source: patch_db::value::ErrorSource::custom(format!("expected object found {v}")),
kind: patch_db::value::ErrorKind::Serialization,
}
.into()),
}
}
}
impl<T: Map> Model<T>
where
T::Key: DeserializeOwned + Ord + Clone,
{
pub fn keys(&self) -> Result<Vec<T::Key>, Error> {
use serde::de::Error;
use serde::Deserialize;
match &self.value {
Value::Object(o) => o
.keys()
.cloned()
.map(|k| {
T::Key::deserialize(patch_db::value::de::InternedStringDeserializer::from(k))
.map_err(|e| {
patch_db::value::Error {
kind: patch_db::value::ErrorKind::Deserialization,
source: e,
}
.into()
})
})
.collect(),
v => Err(patch_db::value::Error {
source: patch_db::value::ErrorSource::custom(format!("expected object found {v}")),
kind: patch_db::value::ErrorKind::Deserialization,
}
.into()),
}
}
pub fn into_entries(self) -> Result<Vec<(T::Key, Model<T::Value>)>, Error> {
use patch_db::ModelExt;
use serde::de::Error;
use serde::Deserialize;
match self.value {
Value::Object(o) => o
.into_iter()
.map(|(k, v)| {
Ok((
T::Key::deserialize(patch_db::value::de::InternedStringDeserializer::from(
k,
))
.with_kind(ErrorKind::Deserialization)?,
Model::from_value(v),
))
})
.collect(),
v => Err(patch_db::value::Error {
source: patch_db::value::ErrorSource::custom(format!("expected object found {v}")),
kind: patch_db::value::ErrorKind::Deserialization,
}
.into()),
}
}
pub fn as_entries(&self) -> Result<Vec<(T::Key, &Model<T::Value>)>, Error> {
use patch_db::ModelExt;
use serde::de::Error;
use serde::Deserialize;
match &self.value {
Value::Object(o) => o
.iter()
.map(|(k, v)| {
Ok((
T::Key::deserialize(patch_db::value::de::InternedStringDeserializer::from(
k.clone(),
))
.with_kind(ErrorKind::Deserialization)?,
Model::value_as(v),
))
})
.collect(),
v => Err(patch_db::value::Error {
source: patch_db::value::ErrorSource::custom(format!("expected object found {v}")),
kind: patch_db::value::ErrorKind::Deserialization,
}
.into()),
}
}
pub fn as_entries_mut(&mut self) -> Result<Vec<(T::Key, &mut Model<T::Value>)>, Error> {
use patch_db::ModelExt;
use serde::de::Error;
use serde::Deserialize;
match &mut self.value {
Value::Object(o) => o
.iter_mut()
.map(|(k, v)| {
Ok((
T::Key::deserialize(patch_db::value::de::InternedStringDeserializer::from(
k.clone(),
))
.with_kind(ErrorKind::Deserialization)?,
Model::value_as_mut(v),
))
})
.collect(),
v => Err(patch_db::value::Error {
source: patch_db::value::ErrorSource::custom(format!("expected object found {v}")),
kind: patch_db::value::ErrorKind::Deserialization,
}
.into()),
}
}
}
impl<T: Map> Model<T>
where
T::Key: AsRef<str>,
{
pub fn into_idx(self, key: &T::Key) -> Option<Model<T::Value>> {
use patch_db::ModelExt;
match &self.value {
Value::Object(o) if o.contains_key(key.as_ref()) => Some(self.transmute(|v| {
use patch_db::value::index::Index;
key.as_ref().index_into_owned(v).unwrap()
})),
_ => None,
}
}
pub fn as_idx<'a>(&'a self, key: &T::Key) -> Option<&'a Model<T::Value>> {
use patch_db::ModelExt;
match &self.value {
Value::Object(o) if o.contains_key(key.as_ref()) => Some(self.transmute_ref(|v| {
use patch_db::value::index::Index;
key.as_ref().index_into(v).unwrap()
})),
_ => None,
}
}
pub fn as_idx_mut<'a>(&'a mut self, key: &T::Key) -> Option<&'a mut Model<T::Value>> {
use patch_db::ModelExt;
match &mut self.value {
Value::Object(o) if o.contains_key(key.as_ref()) => Some(self.transmute_mut(|v| {
use patch_db::value::index::Index;
key.as_ref().index_or_insert(v)
})),
_ => None,
}
}
pub fn remove(&mut self, key: &T::Key) -> Result<Option<Model<T::Value>>, Error> {
use serde::ser::Error;
match &mut self.value {
Value::Object(o) => {
let v = o.remove(key.as_ref());
Ok(v.map(patch_db::ModelExt::from_value))
}
v => Err(patch_db::value::Error {
source: patch_db::value::ErrorSource::custom(format!("expected object found {v}")),
kind: patch_db::value::ErrorKind::Serialization,
}
.into()),
}
}
}

View File

@@ -0,0 +1,363 @@
use std::collections::BTreeMap;
use std::time::Duration;
use color_eyre::eyre::eyre;
use emver::VersionRange;
use models::OptionExt;
use rand::SeedableRng;
use rpc_toolkit::command;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use crate::config::action::ConfigRes;
use crate::config::spec::PackagePointerSpec;
use crate::config::{not_found, Config, ConfigSpec, ConfigureContext};
use crate::context::RpcContext;
use crate::db::model::{CurrentDependencies, Database};
use crate::prelude::*;
use crate::procedure::{NoOutput, PackageProcedure, ProcedureName};
use crate::s9pk::manifest::{Manifest, PackageId};
use crate::status::DependencyConfigErrors;
use crate::util::serde::display_serializable;
use crate::util::{display_none, Version};
use crate::volume::Volumes;
use crate::Error;
#[command(subcommands(configure))]
pub fn dependency() -> Result<(), Error> {
Ok(())
}
#[derive(Clone, Debug, Default, Deserialize, Serialize, HasModel)]
#[model = "Model<Self>"]
pub struct Dependencies(pub BTreeMap<PackageId, DepInfo>);
impl Map for Dependencies {
type Key = PackageId;
type Value = DepInfo;
}
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
#[serde(tag = "type")]
pub enum DependencyRequirement {
OptIn { how: String },
OptOut { how: String },
Required,
}
impl DependencyRequirement {
pub fn required(&self) -> bool {
matches!(self, &DependencyRequirement::Required)
}
}
#[derive(Clone, Debug, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct DepInfo {
pub version: VersionRange,
pub requirement: DependencyRequirement,
pub description: Option<String>,
#[serde(default)]
pub config: Option<DependencyConfig>,
}
#[derive(Clone, Debug, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct DependencyConfig {
check: PackageProcedure,
auto_configure: PackageProcedure,
}
impl DependencyConfig {
pub async fn check(
&self,
ctx: &RpcContext,
dependent_id: &PackageId,
dependent_version: &Version,
dependent_volumes: &Volumes,
dependency_id: &PackageId,
dependency_config: &Config,
) -> Result<Result<NoOutput, String>, Error> {
Ok(self
.check
.sandboxed(
ctx,
dependent_id,
dependent_version,
dependent_volumes,
Some(dependency_config),
None,
ProcedureName::Check(dependency_id.clone()),
)
.await?
.map_err(|(_, e)| e))
}
pub async fn auto_configure(
&self,
ctx: &RpcContext,
dependent_id: &PackageId,
dependent_version: &Version,
dependent_volumes: &Volumes,
old: &Config,
) -> Result<Config, Error> {
self.auto_configure
.sandboxed(
ctx,
dependent_id,
dependent_version,
dependent_volumes,
Some(old),
None,
ProcedureName::AutoConfig(dependent_id.clone()),
)
.await?
.map_err(|e| Error::new(eyre!("{}", e.1), crate::ErrorKind::AutoConfigure))
}
}
#[command(
subcommands(self(configure_impl(async)), configure_dry),
display(display_none)
)]
pub async fn configure(
#[arg(rename = "dependent-id")] dependent_id: PackageId,
#[arg(rename = "dependency-id")] dependency_id: PackageId,
) -> Result<(PackageId, PackageId), Error> {
Ok((dependent_id, dependency_id))
}
pub async fn configure_impl(
ctx: RpcContext,
(pkg_id, dep_id): (PackageId, PackageId),
) -> Result<(), Error> {
let breakages = BTreeMap::new();
let overrides = Default::default();
let ConfigDryRes {
old_config: _,
new_config,
spec: _,
} = configure_logic(ctx.clone(), (pkg_id, dep_id.clone())).await?;
let configure_context = ConfigureContext {
breakages,
timeout: Some(Duration::from_secs(3).into()),
config: Some(new_config),
dry_run: false,
overrides,
};
crate::config::configure(&ctx, &dep_id, configure_context).await?;
Ok(())
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct ConfigDryRes {
pub old_config: Config,
pub new_config: Config,
pub spec: ConfigSpec,
}
#[command(rename = "dry", display(display_serializable))]
#[instrument(skip_all)]
pub async fn configure_dry(
#[context] ctx: RpcContext,
#[parent_data] (pkg_id, dependency_id): (PackageId, PackageId),
) -> Result<ConfigDryRes, Error> {
configure_logic(ctx, (pkg_id, dependency_id)).await
}
pub async fn configure_logic(
ctx: RpcContext,
(pkg_id, dependency_id): (PackageId, PackageId),
) -> Result<ConfigDryRes, Error> {
let db = ctx.db.peek().await;
let pkg = db
.as_package_data()
.as_idx(&pkg_id)
.or_not_found(&pkg_id)?
.as_installed()
.or_not_found(&pkg_id)?;
let pkg_version = pkg.as_manifest().as_version().de()?;
let pkg_volumes = pkg.as_manifest().as_volumes().de()?;
let dependency = db
.as_package_data()
.as_idx(&dependency_id)
.or_not_found(&dependency_id)?
.as_installed()
.or_not_found(&dependency_id)?;
let dependency_config_action = dependency
.as_manifest()
.as_config()
.de()?
.ok_or_else(|| not_found!("Manifest Config"))?;
let dependency_version = dependency.as_manifest().as_version().de()?;
let dependency_volumes = dependency.as_manifest().as_volumes().de()?;
let dependency = pkg
.as_manifest()
.as_dependencies()
.as_idx(&dependency_id)
.or_not_found(&dependency_id)?;
let ConfigRes {
config: maybe_config,
spec,
} = dependency_config_action
.get(
&ctx,
&dependency_id,
&dependency_version,
&dependency_volumes,
)
.await?;
let old_config = if let Some(config) = maybe_config {
config
} else {
spec.gen(
&mut rand::rngs::StdRng::from_entropy(),
&Some(Duration::new(10, 0)),
)?
};
let new_config = dependency
.as_config()
.de()?
.ok_or_else(|| not_found!("Config"))?
.auto_configure
.sandboxed(
&ctx,
&pkg_id,
&pkg_version,
&pkg_volumes,
Some(&old_config),
None,
ProcedureName::AutoConfig(dependency_id.clone()),
)
.await?
.map_err(|e| Error::new(eyre!("{}", e.1), crate::ErrorKind::AutoConfigure))?;
Ok(ConfigDryRes {
old_config,
new_config,
spec,
})
}
#[instrument(skip_all)]
pub fn add_dependent_to_current_dependents_lists(
db: &mut Model<Database>,
dependent_id: &PackageId,
current_dependencies: &CurrentDependencies,
) -> Result<(), Error> {
for (dependency, dep_info) in &current_dependencies.0 {
if let Some(dependency_dependents) = db
.as_package_data_mut()
.as_idx_mut(dependency)
.and_then(|pde| pde.as_installed_mut())
.map(|i| i.as_current_dependents_mut())
{
dependency_dependents.insert(dependent_id, dep_info)?;
}
}
Ok(())
}
pub fn set_dependents_with_live_pointers_to_needs_config(
db: &mut Peeked,
id: &PackageId,
) -> Result<Vec<(PackageId, Version)>, Error> {
let mut res = Vec::new();
for (dep, info) in db
.as_package_data()
.as_idx(id)
.or_not_found(id)?
.as_installed()
.or_not_found(id)?
.as_current_dependents()
.de()?
.0
{
if info.pointers.iter().any(|ptr| match ptr {
// dependency id matches the package being uninstalled
PackagePointerSpec::TorAddress(ptr) => &ptr.package_id == id && &dep != id,
PackagePointerSpec::LanAddress(ptr) => &ptr.package_id == id && &dep != id,
// we never need to retarget these
PackagePointerSpec::TorKey(_) => false,
PackagePointerSpec::Config(_) => false,
}) {
let installed = db
.as_package_data_mut()
.as_idx_mut(&dep)
.or_not_found(&dep)?
.as_installed_mut()
.or_not_found(&dep)?;
let version = installed.as_manifest().as_version().de()?;
let configured = installed.as_status_mut().as_configured_mut();
if configured.de()? {
configured.ser(&false)?;
res.push((dep, version));
}
}
}
Ok(res)
}
#[instrument(skip_all)]
pub async fn compute_dependency_config_errs(
ctx: &RpcContext,
db: &Peeked,
manifest: &Manifest,
current_dependencies: &CurrentDependencies,
dependency_config: &BTreeMap<PackageId, Config>,
) -> Result<DependencyConfigErrors, Error> {
let mut dependency_config_errs = BTreeMap::new();
for (dependency, _dep_info) in current_dependencies
.0
.iter()
.filter(|(dep_id, _)| dep_id != &&manifest.id)
{
// check if config passes dependency check
if let Some(cfg) = &manifest
.dependencies
.0
.get(dependency)
.or_not_found(dependency)?
.config
{
if let Err(error) = cfg
.check(
ctx,
&manifest.id,
&manifest.version,
&manifest.volumes,
dependency,
&if let Some(config) = dependency_config.get(dependency) {
config.clone()
} else if let Some(manifest) = db
.as_package_data()
.as_idx(dependency)
.and_then(|pde| pde.as_installed())
.map(|i| i.as_manifest().de())
.transpose()?
{
if let Some(config) = &manifest.config {
config
.get(ctx, &manifest.id, &manifest.version, &manifest.volumes)
.await?
.config
.unwrap_or_default()
} else {
Config::default()
}
} else {
Config::default()
},
)
.await?
{
dependency_config_errs.insert(dependency.clone(), error);
}
}
}
Ok(DependencyConfigErrors(dependency_config_errs))
}

View File

@@ -0,0 +1,55 @@
use std::fs::File;
use std::io::Write;
use std::path::Path;
use ed25519::pkcs8::EncodePrivateKey;
use ed25519::PublicKeyBytes;
use ed25519_dalek::{SigningKey, VerifyingKey};
use rpc_toolkit::command;
use tracing::instrument;
use crate::context::SdkContext;
use crate::util::display_none;
use crate::{Error, ResultExt};
#[command(cli_only, blocking, display(display_none))]
#[instrument(skip_all)]
pub fn init(#[context] ctx: SdkContext) -> Result<(), Error> {
if !ctx.developer_key_path.exists() {
let parent = ctx.developer_key_path.parent().unwrap_or(Path::new("/"));
if !parent.exists() {
std::fs::create_dir_all(parent)
.with_ctx(|_| (crate::ErrorKind::Filesystem, parent.display().to_string()))?;
}
tracing::info!("Generating new developer key...");
let secret = SigningKey::generate(&mut rand::thread_rng());
tracing::info!("Writing key to {}", ctx.developer_key_path.display());
let keypair_bytes = ed25519::KeypairBytes {
secret_key: secret.to_bytes(),
public_key: Some(PublicKeyBytes(VerifyingKey::from(&secret).to_bytes())),
};
let mut dev_key_file = File::create(&ctx.developer_key_path)?;
dev_key_file.write_all(
keypair_bytes
.to_pkcs8_pem(base64ct::LineEnding::default())
.with_kind(crate::ErrorKind::Pem)?
.as_bytes(),
)?;
dev_key_file.sync_all()?;
println!(
"New developer key generated at {}",
ctx.developer_key_path.display()
);
} else {
println!(
"Developer key already exists at {}",
ctx.developer_key_path.display()
);
}
Ok(())
}
#[command(subcommands(crate::s9pk::verify, crate::config::verify_spec))]
pub fn verify() -> Result<(), Error> {
Ok(())
}

View File

@@ -0,0 +1,72 @@
use std::path::Path;
use std::sync::Arc;
use rpc_toolkit::command;
use rpc_toolkit::yajrc::RpcError;
use crate::context::DiagnosticContext;
use crate::disk::repair;
use crate::init::SYSTEM_REBUILD_PATH;
use crate::logs::{fetch_logs, LogResponse, LogSource};
use crate::shutdown::Shutdown;
use crate::util::display_none;
use crate::Error;
#[command(subcommands(error, logs, exit, restart, forget_disk, disk, rebuild))]
pub fn diagnostic() -> Result<(), Error> {
Ok(())
}
#[command]
pub fn error(#[context] ctx: DiagnosticContext) -> Result<Arc<RpcError>, Error> {
Ok(ctx.error.clone())
}
#[command(rpc_only)]
pub async fn logs(
#[arg] limit: Option<usize>,
#[arg] cursor: Option<String>,
#[arg] before: bool,
) -> Result<LogResponse, Error> {
Ok(fetch_logs(LogSource::System, limit, cursor, before).await?)
}
#[command(display(display_none))]
pub fn exit(#[context] ctx: DiagnosticContext) -> Result<(), Error> {
ctx.shutdown.send(None).expect("receiver dropped");
Ok(())
}
#[command(display(display_none))]
pub fn restart(#[context] ctx: DiagnosticContext) -> Result<(), Error> {
ctx.shutdown
.send(Some(Shutdown {
export_args: ctx
.disk_guid
.clone()
.map(|guid| (guid, ctx.datadir.clone())),
restart: true,
}))
.expect("receiver dropped");
Ok(())
}
#[command(display(display_none))]
pub async fn rebuild(#[context] ctx: DiagnosticContext) -> Result<(), Error> {
tokio::fs::write(SYSTEM_REBUILD_PATH, b"").await?;
restart(ctx)
}
#[command(subcommands(forget_disk, repair))]
pub fn disk() -> Result<(), Error> {
Ok(())
}
#[command(rename = "forget", display(display_none))]
pub async fn forget_disk() -> Result<(), Error> {
let disk_guid = Path::new("/media/embassy/config/disk.guid");
if tokio::fs::metadata(disk_guid).await.is_ok() {
tokio::fs::remove_file(disk_guid).await?;
}
Ok(())
}

View File

@@ -0,0 +1,32 @@
use std::path::Path;
use tokio::process::Command;
use tracing::instrument;
use crate::disk::fsck::RequiresReboot;
use crate::util::Invoke;
use crate::Error;
#[instrument(skip_all)]
pub async fn btrfs_check_readonly(logicalname: impl AsRef<Path>) -> Result<RequiresReboot, Error> {
Command::new("btrfs")
.arg("check")
.arg("--readonly")
.arg(logicalname.as_ref())
.invoke(crate::ErrorKind::DiskManagement)
.await?;
Ok(RequiresReboot(false))
}
pub async fn btrfs_check_repair(logicalname: impl AsRef<Path>) -> Result<RequiresReboot, Error> {
Command::new("btrfs")
.arg("check")
.arg("--repair")
.arg("--force")
.arg(logicalname.as_ref())
.invoke(crate::ErrorKind::DiskManagement)
.await?;
Ok(RequiresReboot(false))
}

View File

@@ -0,0 +1,95 @@
use std::ffi::OsStr;
use std::path::Path;
use color_eyre::eyre::eyre;
use futures::future::BoxFuture;
use futures::FutureExt;
use tokio::process::Command;
use tracing::instrument;
use crate::disk::fsck::RequiresReboot;
use crate::Error;
#[instrument(skip_all)]
pub async fn e2fsck_preen(
logicalname: impl AsRef<Path> + std::fmt::Debug,
) -> Result<RequiresReboot, Error> {
e2fsck_runner(Command::new("e2fsck").arg("-p"), logicalname).await
}
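/// If an e2undo file from a previous run already exists, recursively shift it (and any earlier
/// backups) one `.bak` suffix deeper before it gets overwritten.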
fn backup_existing_undo_file<'a>(path: &'a Path) -> BoxFuture<'a, Result<(), Error>> {
async move {
if tokio::fs::metadata(path).await.is_ok() {
let bak = path.with_extension(format!(
"{}.bak",
path.extension()
.and_then(|s| s.to_str())
.unwrap_or_default()
));
backup_existing_undo_file(&bak).await?;
tokio::fs::rename(path, &bak).await?;
}
Ok(())
}
.boxed()
}
#[instrument(skip_all)]
pub async fn e2fsck_aggressive(
logicalname: impl AsRef<Path> + std::fmt::Debug,
) -> Result<RequiresReboot, Error> {
let undo_path = Path::new("/media/embassy/config")
.join(
logicalname
.as_ref()
.file_name()
.unwrap_or(OsStr::new("unknown")),
)
.with_extension("e2undo");
backup_existing_undo_file(&undo_path).await?;
e2fsck_runner(
Command::new("e2fsck").arg("-y").arg("-z").arg(undo_path),
logicalname,
)
.await
}
async fn e2fsck_runner(
e2fsck_cmd: &mut Command,
logicalname: impl AsRef<Path> + std::fmt::Debug,
) -> Result<RequiresReboot, Error> {
let e2fsck_out = e2fsck_cmd.arg(logicalname.as_ref()).output().await?;
let e2fsck_stderr = String::from_utf8(e2fsck_out.stderr)?;
let code = e2fsck_out.status.code().ok_or_else(|| {
Error::new(
eyre!("e2fsck: process terminated by signal"),
crate::ErrorKind::DiskManagement,
)
})?;
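// e2fsck's exit status is a bitmask: 1 = errors corrected, 2 = system should be rebooted,
// 4 = errors left uncorrected, 8 and above = operational or usage errors.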
if code & 4 != 0 {
tracing::error!(
"some filesystem errors NOT corrected on {}:\n{}",
logicalname.as_ref().display(),
e2fsck_stderr,
);
} else if code & 1 != 0 {
tracing::warn!(
"filesystem errors corrected on {}:\n{}",
logicalname.as_ref().display(),
e2fsck_stderr,
);
}
if code < 8 {
if code & 2 != 0 {
tracing::warn!("reboot required");
Ok(RequiresReboot(true))
} else {
Ok(RequiresReboot(false))
}
} else {
Err(Error::new(
eyre!("e2fsck: {}", e2fsck_stderr),
crate::ErrorKind::DiskManagement,
))
}
}

View File

@@ -0,0 +1,70 @@
use std::path::Path;
use color_eyre::eyre::eyre;
use tokio::process::Command;
use crate::disk::fsck::btrfs::{btrfs_check_readonly, btrfs_check_repair};
use crate::disk::fsck::ext4::{e2fsck_aggressive, e2fsck_preen};
use crate::util::Invoke;
use crate::Error;
pub mod btrfs;
pub mod ext4;
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct RequiresReboot(pub bool);
impl std::ops::BitOrAssign for RequiresReboot {
fn bitor_assign(&mut self, rhs: Self) {
self.0 |= rhs.0
}
}
#[derive(Debug, Clone, Copy)]
pub enum RepairStrategy {
Preen,
Aggressive,
}
impl RepairStrategy {
pub async fn fsck(
&self,
logicalname: impl AsRef<Path> + std::fmt::Debug,
) -> Result<RequiresReboot, Error> {
match &*String::from_utf8(
Command::new("grub-probe")
.arg("-d")
.arg(logicalname.as_ref())
.invoke(crate::ErrorKind::DiskManagement)
.await?,
)?
.trim()
{
"ext2" => self.e2fsck(logicalname).await,
"btrfs" => self.btrfs_check(logicalname).await,
fs => {
return Err(Error::new(
eyre!("Unknown filesystem {fs}"),
crate::ErrorKind::DiskManagement,
))
}
}
}
pub async fn e2fsck(
&self,
logicalname: impl AsRef<Path> + std::fmt::Debug,
) -> Result<RequiresReboot, Error> {
match self {
RepairStrategy::Preen => e2fsck_preen(logicalname).await,
RepairStrategy::Aggressive => e2fsck_aggressive(logicalname).await,
}
}
pub async fn btrfs_check(
&self,
logicalname: impl AsRef<Path> + std::fmt::Debug,
) -> Result<RequiresReboot, Error> {
match self {
RepairStrategy::Preen => btrfs_check_readonly(logicalname).await,
RepairStrategy::Aggressive => btrfs_check_repair(logicalname).await,
}
}
}

View File

@@ -0,0 +1,337 @@
use std::collections::BTreeMap;
use std::path::{Path, PathBuf};
use color_eyre::eyre::eyre;
use tokio::process::Command;
use tracing::instrument;
use super::fsck::{RepairStrategy, RequiresReboot};
use super::util::pvscan;
use crate::disk::mount::filesystem::block_dev::mount;
use crate::disk::mount::filesystem::ReadWrite;
use crate::disk::mount::util::unmount;
use crate::util::Invoke;
use crate::{Error, ErrorKind, ResultExt};
pub const PASSWORD_PATH: &'static str = "/run/embassy/password";
pub const DEFAULT_PASSWORD: &'static str = "password";
pub const MAIN_FS_SIZE: FsSize = FsSize::Gigabytes(8);
#[instrument(skip_all)]
pub async fn create<I, P>(
disks: &I,
pvscan: &BTreeMap<PathBuf, Option<String>>,
datadir: impl AsRef<Path>,
password: Option<&str>,
) -> Result<String, Error>
where
for<'a> &'a I: IntoIterator<Item = &'a P>,
P: AsRef<Path>,
{
let guid = create_pool(disks, pvscan, password.is_some()).await?;
create_all_fs(&guid, &datadir, password).await?;
export(&guid, datadir).await?;
Ok(guid)
}
#[instrument(skip_all)]
pub async fn create_pool<I, P>(
disks: &I,
pvscan: &BTreeMap<PathBuf, Option<String>>,
encrypted: bool,
) -> Result<String, Error>
where
for<'a> &'a I: IntoIterator<Item = &'a P>,
P: AsRef<Path>,
{
Command::new("dmsetup")
.arg("remove_all") // TODO: find a higher finesse way to do this for portability reasons
.invoke(crate::ErrorKind::DiskManagement)
.await?;
for disk in disks {
if pvscan.contains_key(disk.as_ref()) {
Command::new("pvremove")
.arg("-yff")
.arg(disk.as_ref())
.invoke(crate::ErrorKind::DiskManagement)
.await?;
}
tokio::fs::write(disk.as_ref(), &[0; 2048]).await?; // wipe partition table
Command::new("pvcreate")
.arg("-yff")
.arg(disk.as_ref())
.invoke(crate::ErrorKind::DiskManagement)
.await?;
}
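// The volume group name doubles as the disk GUID: "EMBASSY_" plus a base32-encoded random
// 32-byte value, with an "_UNENC" suffix appended when no encryption password was given.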
let mut guid = format!(
"EMBASSY_{}",
base32::encode(
base32::Alphabet::RFC4648 { padding: false },
&rand::random::<[u8; 32]>(),
)
);
if !encrypted {
guid += "_UNENC";
}
let mut cmd = Command::new("vgcreate");
cmd.arg("-y").arg(&guid);
for disk in disks {
cmd.arg(disk.as_ref());
}
cmd.invoke(crate::ErrorKind::DiskManagement).await?;
Ok(guid)
}
#[derive(Debug, Clone, Copy)]
pub enum FsSize {
Gigabytes(usize),
FreePercentage(usize),
}
#[instrument(skip_all)]
pub async fn create_fs<P: AsRef<Path>>(
guid: &str,
datadir: P,
name: &str,
size: FsSize,
password: Option<&str>,
) -> Result<(), Error> {
let mut cmd = Command::new("lvcreate");
match size {
FsSize::Gigabytes(a) => cmd.arg("-L").arg(format!("{}G", a)),
FsSize::FreePercentage(a) => cmd.arg("-l").arg(format!("{}%FREE", a)),
};
cmd.arg("-y")
.arg("-n")
.arg(name)
.arg(guid)
.invoke(crate::ErrorKind::DiskManagement)
.await?;
let mut blockdev_path = Path::new("/dev").join(guid).join(name);
if let Some(password) = password {
if let Some(parent) = Path::new(PASSWORD_PATH).parent() {
tokio::fs::create_dir_all(parent).await?;
}
tokio::fs::write(PASSWORD_PATH, password)
.await
.with_ctx(|_| (crate::ErrorKind::Filesystem, PASSWORD_PATH))?;
Command::new("cryptsetup")
.arg("-q")
.arg("luksFormat")
.arg(format!("--key-file={}", PASSWORD_PATH))
.arg(format!("--keyfile-size={}", password.len()))
.arg(&blockdev_path)
.invoke(crate::ErrorKind::DiskManagement)
.await?;
Command::new("cryptsetup")
.arg("-q")
.arg("luksOpen")
.arg("--allow-discards")
.arg(format!("--key-file={}", PASSWORD_PATH))
.arg(format!("--keyfile-size={}", password.len()))
.arg(&blockdev_path)
.arg(format!("{}_{}", guid, name))
.invoke(crate::ErrorKind::DiskManagement)
.await?;
tokio::fs::remove_file(PASSWORD_PATH)
.await
.with_ctx(|_| (crate::ErrorKind::Filesystem, PASSWORD_PATH))?;
blockdev_path = Path::new("/dev/mapper").join(format!("{}_{}", guid, name));
}
Command::new("mkfs.btrfs")
.arg(&blockdev_path)
.invoke(crate::ErrorKind::DiskManagement)
.await?;
mount(&blockdev_path, datadir.as_ref().join(name), ReadWrite).await?;
Ok(())
}
#[instrument(skip_all)]
pub async fn create_all_fs<P: AsRef<Path>>(
guid: &str,
datadir: P,
password: Option<&str>,
) -> Result<(), Error> {
create_fs(guid, &datadir, "main", MAIN_FS_SIZE, password).await?;
create_fs(
guid,
&datadir,
"package-data",
FsSize::FreePercentage(100),
password,
)
.await?;
Ok(())
}
#[instrument(skip_all)]
pub async fn unmount_fs<P: AsRef<Path>>(guid: &str, datadir: P, name: &str) -> Result<(), Error> {
unmount(datadir.as_ref().join(name)).await?;
if !guid.ends_with("_UNENC") {
Command::new("cryptsetup")
.arg("-q")
.arg("luksClose")
.arg(format!("{}_{}", guid, name))
.invoke(crate::ErrorKind::DiskManagement)
.await?;
}
Ok(())
}
#[instrument(skip_all)]
pub async fn unmount_all_fs<P: AsRef<Path>>(guid: &str, datadir: P) -> Result<(), Error> {
unmount_fs(guid, &datadir, "main").await?;
unmount_fs(guid, &datadir, "package-data").await?;
Command::new("dmsetup")
.arg("remove_all") // TODO: find a higher finesse way to do this for portability reasons
.invoke(crate::ErrorKind::DiskManagement)
.await?;
Ok(())
}
#[instrument(skip_all)]
pub async fn export<P: AsRef<Path>>(guid: &str, datadir: P) -> Result<(), Error> {
Command::new("sync").invoke(ErrorKind::Filesystem).await?;
unmount_all_fs(guid, datadir).await?;
Command::new("vgchange")
.arg("-an")
.arg(guid)
.invoke(crate::ErrorKind::DiskManagement)
.await?;
Command::new("vgexport")
.arg(guid)
.invoke(crate::ErrorKind::DiskManagement)
.await?;
Ok(())
}
#[instrument(skip_all)]
pub async fn import<P: AsRef<Path>>(
guid: &str,
datadir: P,
repair: RepairStrategy,
password: Option<&str>,
) -> Result<RequiresReboot, Error> {
let scan = pvscan().await?;
if scan
.values()
.filter_map(|a| a.as_ref())
.filter(|a| a.starts_with("EMBASSY_"))
.next()
.is_none()
{
return Err(Error::new(
eyre!("StartOS disk not found."),
crate::ErrorKind::DiskNotAvailable,
));
}
if !scan
.values()
.filter_map(|a| a.as_ref())
.any(|id| id == guid)
{
return Err(Error::new(
eyre!("A StartOS disk was found, but it is not the correct disk for this device."),
crate::ErrorKind::IncorrectDisk,
));
}
Command::new("dmsetup")
.arg("remove_all") // TODO: find a higher finesse way to do this for portability reasons
.invoke(crate::ErrorKind::DiskManagement)
.await?;
match Command::new("vgimport")
.arg(guid)
.invoke(crate::ErrorKind::DiskManagement)
.await
{
Ok(_) => Ok(()),
Err(e)
if format!("{}", e.source)
.lines()
.any(|l| l.trim() == format!("Volume group \"{}\" is not exported", guid)) =>
{
Ok(())
}
Err(e) => Err(e),
}?;
Command::new("vgchange")
.arg("-ay")
.arg(guid)
.invoke(crate::ErrorKind::DiskManagement)
.await?;
mount_all_fs(guid, datadir, repair, password).await
}
#[instrument(skip_all)]
pub async fn mount_fs<P: AsRef<Path>>(
guid: &str,
datadir: P,
name: &str,
repair: RepairStrategy,
password: Option<&str>,
) -> Result<RequiresReboot, Error> {
let orig_path = Path::new("/dev").join(guid).join(name);
let mut blockdev_path = orig_path.clone();
let full_name = format!("{}_{}", guid, name);
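// Encrypted volumes (no "_UNENC" suffix) must be unlocked with cryptsetup first; if no
// password was supplied, DEFAULT_PASSWORD is used.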
if !guid.ends_with("_UNENC") {
let password = password.unwrap_or(DEFAULT_PASSWORD);
if let Some(parent) = Path::new(PASSWORD_PATH).parent() {
tokio::fs::create_dir_all(parent).await?;
}
tokio::fs::write(PASSWORD_PATH, password)
.await
.with_ctx(|_| (crate::ErrorKind::Filesystem, PASSWORD_PATH))?;
Command::new("cryptsetup")
.arg("-q")
.arg("luksOpen")
.arg(format!("--key-file={}", PASSWORD_PATH))
.arg(format!("--keyfile-size={}", password.len()))
.arg(&blockdev_path)
.arg(&full_name)
.invoke(crate::ErrorKind::DiskManagement)
.await?;
tokio::fs::remove_file(PASSWORD_PATH)
.await
.with_ctx(|_| (crate::ErrorKind::Filesystem, PASSWORD_PATH))?;
blockdev_path = Path::new("/dev/mapper").join(&full_name);
}
let reboot = repair.fsck(&blockdev_path).await?;
if !guid.ends_with("_UNENC") {
// Backup LUKS header if e2fsck succeeded
let luks_folder = Path::new("/media/embassy/config/luks");
tokio::fs::create_dir_all(luks_folder).await?;
let tmp_luks_bak = luks_folder.join(format!(".{full_name}.luks.bak.tmp"));
if tokio::fs::metadata(&tmp_luks_bak).await.is_ok() {
tokio::fs::remove_file(&tmp_luks_bak).await?;
}
let luks_bak = luks_folder.join(format!("{full_name}.luks.bak"));
Command::new("cryptsetup")
.arg("-q")
.arg("luksHeaderBackup")
.arg("--header-backup-file")
.arg(&tmp_luks_bak)
.arg(&orig_path)
.invoke(crate::ErrorKind::DiskManagement)
.await?;
tokio::fs::rename(&tmp_luks_bak, &luks_bak).await?;
}
mount(&blockdev_path, datadir.as_ref().join(name), ReadWrite).await?;
Ok(reboot)
}
#[instrument(skip_all)]
pub async fn mount_all_fs<P: AsRef<Path>>(
guid: &str,
datadir: P,
repair: RepairStrategy,
password: Option<&str>,
) -> Result<RequiresReboot, Error> {
let mut reboot = RequiresReboot(false);
reboot |= mount_fs(guid, &datadir, "main", repair, password).await?;
reboot |= mount_fs(guid, &datadir, "package-data", repair, password).await?;
Ok(reboot)
}

View File

@@ -0,0 +1,118 @@
use std::path::{Path, PathBuf};
use clap::ArgMatches;
use rpc_toolkit::command;
use serde::{Deserialize, Serialize};
use crate::context::RpcContext;
use crate::disk::util::DiskInfo;
use crate::util::display_none;
use crate::util::serde::{display_serializable, IoFormat};
use crate::Error;
pub mod fsck;
pub mod main;
pub mod mount;
pub mod util;
pub const BOOT_RW_PATH: &str = "/media/boot-rw";
pub const REPAIR_DISK_PATH: &str = "/media/embassy/config/repair-disk";
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct OsPartitionInfo {
pub efi: Option<PathBuf>,
pub bios: Option<PathBuf>,
pub boot: PathBuf,
pub root: PathBuf,
}
impl OsPartitionInfo {
pub fn contains(&self, logicalname: impl AsRef<Path>) -> bool {
self.efi
.as_ref()
.map(|p| p == logicalname.as_ref())
.unwrap_or(false)
|| self
.bios
.as_ref()
.map(|p| p == logicalname.as_ref())
.unwrap_or(false)
|| &*self.boot == logicalname.as_ref()
|| &*self.root == logicalname.as_ref()
}
}
#[command(subcommands(list, repair))]
pub fn disk() -> Result<(), Error> {
Ok(())
}
fn display_disk_info(info: Vec<DiskInfo>, matches: &ArgMatches) {
use prettytable::*;
if matches.is_present("format") {
return display_serializable(info, matches);
}
let mut table = Table::new();
table.add_row(row![bc =>
"LOGICALNAME",
"LABEL",
"CAPACITY",
"USED",
"EMBASSY OS VERSION"
]);
for disk in info {
let row = row![
disk.logicalname.display(),
"N/A",
&format!("{:.2} GiB", disk.capacity as f64 / 1024.0 / 1024.0 / 1024.0),
"N/A",
"N/A",
];
table.add_row(row);
for part in disk.partitions {
let row = row![
part.logicalname.display(),
if let Some(label) = part.label.as_ref() {
label
} else {
"N/A"
},
part.capacity,
if let Some(used) = part
.used
.map(|u| format!("{:.2} GiB", u as f64 / 1024.0 / 1024.0 / 1024.0))
.as_ref()
{
used
} else {
"N/A"
},
if let Some(eos) = part.embassy_os.as_ref() {
eos.version.as_str()
} else {
"N/A"
},
];
table.add_row(row);
}
}
table.print_tty(false).unwrap();
}
#[command(display(display_disk_info))]
pub async fn list(
#[context] ctx: RpcContext,
#[allow(unused_variables)]
#[arg]
format: Option<IoFormat>,
) -> Result<Vec<DiskInfo>, Error> {
crate::disk::util::list(&ctx.os_partitions).await
}
#[command(display(display_none))]
pub async fn repair() -> Result<(), Error> {
tokio::fs::write(REPAIR_DISK_PATH, b"").await?;
Ok(())
}

View File

@@ -0,0 +1,262 @@
use std::path::{Path, PathBuf};
use color_eyre::eyre::eyre;
use helpers::AtomicFile;
use tokio::io::AsyncWriteExt;
use tracing::instrument;
use super::filesystem::ecryptfs::EcryptFS;
use super::guard::{GenericMountGuard, TmpMountGuard};
use super::util::{bind, unmount};
use crate::auth::check_password;
use crate::backup::target::BackupInfo;
use crate::disk::mount::filesystem::ReadWrite;
use crate::disk::util::EmbassyOsRecoveryInfo;
use crate::middleware::encrypt::{decrypt_slice, encrypt_slice};
use crate::s9pk::manifest::PackageId;
use crate::util::serde::IoFormat;
use crate::util::FileLock;
use crate::volume::BACKUP_DIR;
use crate::{Error, ErrorKind, ResultExt};
pub struct BackupMountGuard<G: GenericMountGuard> {
backup_disk_mount_guard: Option<G>,
encrypted_guard: Option<TmpMountGuard>,
enc_key: String,
pub unencrypted_metadata: EmbassyOsRecoveryInfo,
pub metadata: BackupInfo,
}
impl<G: GenericMountGuard> BackupMountGuard<G> {
fn backup_disk_path(&self) -> &Path {
if let Some(guard) = &self.backup_disk_mount_guard {
guard.as_ref()
} else {
unreachable!()
}
}
#[instrument(skip_all)]
pub async fn mount(backup_disk_mount_guard: G, password: &str) -> Result<Self, Error> {
let backup_disk_path = backup_disk_mount_guard.as_ref();
let unencrypted_metadata_path =
backup_disk_path.join("EmbassyBackups/unencrypted-metadata.cbor");
let mut unencrypted_metadata: EmbassyOsRecoveryInfo =
if tokio::fs::metadata(&unencrypted_metadata_path)
.await
.is_ok()
{
IoFormat::Cbor.from_slice(
&tokio::fs::read(&unencrypted_metadata_path)
.await
.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
unencrypted_metadata_path.display().to_string(),
)
})?,
)?
} else {
Default::default()
};
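// If the backup already records a password hash and a wrapped key, verify the supplied
// password and unwrap the existing encryption key; otherwise generate a fresh random key
// (its hash and wrapped form are persisted just below).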
let enc_key = if let (Some(hash), Some(wrapped_key)) = (
unencrypted_metadata.password_hash.as_ref(),
unencrypted_metadata.wrapped_key.as_ref(),
) {
let wrapped_key =
base32::decode(base32::Alphabet::RFC4648 { padding: true }, wrapped_key)
.ok_or_else(|| {
Error::new(
eyre!("failed to decode wrapped key"),
crate::ErrorKind::Backup,
)
})?;
check_password(hash, password)?;
String::from_utf8(decrypt_slice(wrapped_key, password))?
} else {
base32::encode(
base32::Alphabet::RFC4648 { padding: false },
&rand::random::<[u8; 32]>()[..],
)
};
if unencrypted_metadata.password_hash.is_none() {
unencrypted_metadata.password_hash = Some(
argon2::hash_encoded(
password.as_bytes(),
&rand::random::<[u8; 16]>()[..],
&argon2::Config::rfc9106_low_mem(),
)
.with_kind(crate::ErrorKind::PasswordHashGeneration)?,
);
}
if unencrypted_metadata.wrapped_key.is_none() {
unencrypted_metadata.wrapped_key = Some(base32::encode(
base32::Alphabet::RFC4648 { padding: true },
&encrypt_slice(&enc_key, password),
));
}
let crypt_path = backup_disk_path.join("EmbassyBackups/crypt");
if tokio::fs::metadata(&crypt_path).await.is_err() {
tokio::fs::create_dir_all(&crypt_path).await.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
crypt_path.display().to_string(),
)
})?;
}
let encrypted_guard =
TmpMountGuard::mount(&EcryptFS::new(&crypt_path, &enc_key), ReadWrite).await?;
let metadata_path = encrypted_guard.as_ref().join("metadata.cbor");
let metadata: BackupInfo = if tokio::fs::metadata(&metadata_path).await.is_ok() {
IoFormat::Cbor.from_slice(&tokio::fs::read(&metadata_path).await.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
metadata_path.display().to_string(),
)
})?)?
} else {
Default::default()
};
Ok(Self {
backup_disk_mount_guard: Some(backup_disk_mount_guard),
encrypted_guard: Some(encrypted_guard),
enc_key,
unencrypted_metadata,
metadata,
})
}
pub fn change_password(&mut self, new_password: &str) -> Result<(), Error> {
self.unencrypted_metadata.password_hash = Some(
argon2::hash_encoded(
new_password.as_bytes(),
&rand::random::<[u8; 16]>()[..],
&argon2::Config::rfc9106_low_mem(),
)
.with_kind(crate::ErrorKind::PasswordHashGeneration)?,
);
self.unencrypted_metadata.wrapped_key = Some(base32::encode(
base32::Alphabet::RFC4648 { padding: false },
&encrypt_slice(&self.enc_key, new_password),
));
Ok(())
}
#[instrument(skip_all)]
pub async fn mount_package_backup(
&self,
id: &PackageId,
) -> Result<PackageBackupMountGuard, Error> {
let lock = FileLock::new(Path::new(BACKUP_DIR).join(format!("{}.lock", id)), false).await?;
let mountpoint = Path::new(BACKUP_DIR).join(id);
bind(self.as_ref().join(id), &mountpoint, false).await?;
Ok(PackageBackupMountGuard {
mountpoint: Some(mountpoint),
lock: Some(lock),
})
}
#[instrument(skip_all)]
pub async fn save(&self) -> Result<(), Error> {
let metadata_path = self.as_ref().join("metadata.cbor");
let backup_disk_path = self.backup_disk_path();
let mut file = AtomicFile::new(&metadata_path, None::<PathBuf>)
.await
.with_kind(ErrorKind::Filesystem)?;
file.write_all(&IoFormat::Cbor.to_vec(&self.metadata)?)
.await?;
file.save().await.with_kind(ErrorKind::Filesystem)?;
let unencrypted_metadata_path =
backup_disk_path.join("EmbassyBackups/unencrypted-metadata.cbor");
let mut file = AtomicFile::new(&unencrypted_metadata_path, None::<PathBuf>)
.await
.with_kind(ErrorKind::Filesystem)?;
file.write_all(&IoFormat::Cbor.to_vec(&self.unencrypted_metadata)?)
.await?;
file.save().await.with_kind(ErrorKind::Filesystem)?;
Ok(())
}
#[instrument(skip_all)]
pub async fn unmount(mut self) -> Result<(), Error> {
if let Some(guard) = self.encrypted_guard.take() {
guard.unmount().await?;
}
if let Some(guard) = self.backup_disk_mount_guard.take() {
guard.unmount().await?;
}
Ok(())
}
#[instrument(skip_all)]
pub async fn save_and_unmount(self) -> Result<(), Error> {
self.save().await?;
self.unmount().await?;
Ok(())
}
}
impl<G: GenericMountGuard> AsRef<Path> for BackupMountGuard<G> {
fn as_ref(&self) -> &Path {
if let Some(guard) = &self.encrypted_guard {
guard.as_ref()
} else {
unreachable!()
}
}
}
impl<G: GenericMountGuard> Drop for BackupMountGuard<G> {
fn drop(&mut self) {
let first = self.encrypted_guard.take();
let second = self.backup_disk_mount_guard.take();
tokio::spawn(async move {
if let Some(guard) = first {
guard.unmount().await.unwrap();
}
if let Some(guard) = second {
guard.unmount().await.unwrap();
}
});
}
}
pub struct PackageBackupMountGuard {
mountpoint: Option<PathBuf>,
lock: Option<FileLock>,
}
impl PackageBackupMountGuard {
pub async fn unmount(mut self) -> Result<(), Error> {
if let Some(mountpoint) = self.mountpoint.take() {
unmount(&mountpoint).await?;
}
if let Some(lock) = self.lock.take() {
lock.unlock().await?;
}
Ok(())
}
}
impl AsRef<Path> for PackageBackupMountGuard {
fn as_ref(&self) -> &Path {
if let Some(mountpoint) = &self.mountpoint {
mountpoint
} else {
unreachable!()
}
}
}
impl Drop for PackageBackupMountGuard {
fn drop(&mut self) {
let mountpoint = self.mountpoint.take();
let lock = self.lock.take();
tokio::spawn(async move {
if let Some(mountpoint) = mountpoint {
unmount(&mountpoint).await.unwrap();
}
if let Some(lock) = lock {
lock.unlock().await.unwrap();
}
});
}
}

View File

@@ -0,0 +1,54 @@
use std::os::unix::ffi::OsStrExt;
use std::path::Path;
use async_trait::async_trait;
use digest::generic_array::GenericArray;
use digest::{Digest, OutputSizeUser};
use sha2::Sha256;
use super::{FileSystem, MountType, ReadOnly};
use crate::disk::mount::util::bind;
use crate::{Error, ResultExt};
pub struct Bind<SrcDir: AsRef<Path>> {
src_dir: SrcDir,
}
impl<SrcDir: AsRef<Path>> Bind<SrcDir> {
pub fn new(src_dir: SrcDir) -> Self {
Self { src_dir }
}
}
#[async_trait]
impl<SrcDir: AsRef<Path> + Send + Sync> FileSystem for Bind<SrcDir> {
async fn mount<P: AsRef<Path> + Send + Sync>(
&self,
mountpoint: P,
mount_type: MountType,
) -> Result<(), Error> {
bind(
self.src_dir.as_ref(),
mountpoint,
matches!(mount_type, ReadOnly),
)
.await
}
async fn source_hash(
&self,
) -> Result<GenericArray<u8, <Sha256 as OutputSizeUser>::OutputSize>, Error> {
let mut sha = Sha256::new();
sha.update("Bind");
sha.update(
tokio::fs::canonicalize(self.src_dir.as_ref())
.await
.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
self.src_dir.as_ref().display().to_string(),
)
})?
.as_os_str()
.as_bytes(),
);
Ok(sha.finalize())
}
}

View File

@@ -0,0 +1,67 @@
use std::os::unix::ffi::OsStrExt;
use std::path::Path;
use async_trait::async_trait;
use digest::generic_array::GenericArray;
use digest::{Digest, OutputSizeUser};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use super::{FileSystem, MountType, ReadOnly};
use crate::util::Invoke;
use crate::{Error, ResultExt};
pub async fn mount(
logicalname: impl AsRef<Path>,
mountpoint: impl AsRef<Path>,
mount_type: MountType,
) -> Result<(), Error> {
tokio::fs::create_dir_all(mountpoint.as_ref()).await?;
let mut cmd = tokio::process::Command::new("mount");
cmd.arg(logicalname.as_ref()).arg(mountpoint.as_ref());
if mount_type == ReadOnly {
cmd.arg("-o").arg("ro");
}
cmd.invoke(crate::ErrorKind::Filesystem).await?;
Ok(())
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct BlockDev<LogicalName: AsRef<Path>> {
logicalname: LogicalName,
}
impl<LogicalName: AsRef<Path>> BlockDev<LogicalName> {
pub fn new(logicalname: LogicalName) -> Self {
BlockDev { logicalname }
}
}
#[async_trait]
impl<LogicalName: AsRef<Path> + Send + Sync> FileSystem for BlockDev<LogicalName> {
async fn mount<P: AsRef<Path> + Send + Sync>(
&self,
mountpoint: P,
mount_type: MountType,
) -> Result<(), Error> {
mount(self.logicalname.as_ref(), mountpoint, mount_type).await
}
async fn source_hash(
&self,
) -> Result<GenericArray<u8, <Sha256 as OutputSizeUser>::OutputSize>, Error> {
let mut sha = Sha256::new();
sha.update("BlockDev");
sha.update(
tokio::fs::canonicalize(self.logicalname.as_ref())
.await
.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
self.logicalname.as_ref().display().to_string(),
)
})?
.as_os_str()
.as_bytes(),
);
Ok(sha.finalize())
}
}

View File

@@ -0,0 +1,107 @@
use std::net::IpAddr;
use std::os::unix::ffi::OsStrExt;
use std::path::{Path, PathBuf};
use async_trait::async_trait;
use digest::generic_array::GenericArray;
use digest::{Digest, OutputSizeUser};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use tokio::process::Command;
use tracing::instrument;
use super::{FileSystem, MountType, ReadOnly};
use crate::disk::mount::guard::TmpMountGuard;
use crate::util::Invoke;
use crate::Error;
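/// Resolution order: literal IP address, then mDNS for `.local` names, then NetBIOS lookup via
/// `nmblookup`.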
async fn resolve_hostname(hostname: &str) -> Result<IpAddr, Error> {
if let Ok(addr) = hostname.parse() {
return Ok(addr);
}
if hostname.ends_with(".local") {
return Ok(IpAddr::V4(crate::net::mdns::resolve_mdns(hostname).await?));
}
Ok(String::from_utf8(
Command::new("nmblookup")
.arg(hostname)
.invoke(crate::ErrorKind::Network)
.await?,
)?
.split(" ")
.next()
.unwrap()
.trim()
.parse()?)
}
#[instrument(skip_all)]
pub async fn mount_cifs(
hostname: &str,
path: impl AsRef<Path>,
username: &str,
password: Option<&str>,
mountpoint: impl AsRef<Path>,
mount_type: MountType,
) -> Result<(), Error> {
tokio::fs::create_dir_all(mountpoint.as_ref()).await?;
let ip: IpAddr = resolve_hostname(hostname).await?;
let absolute_path = Path::new("/").join(path.as_ref());
let mut cmd = Command::new("mount");
cmd.arg("-t")
.arg("cifs")
.env("USER", username)
.env("PASSWD", password.unwrap_or_default())
.arg(format!("//{}{}", ip, absolute_path.display()))
.arg(mountpoint.as_ref());
if mount_type == ReadOnly {
cmd.arg("-o").arg("ro,noserverino");
} else {
cmd.arg("-o").arg("noserverino");
}
cmd.invoke(crate::ErrorKind::Filesystem).await?;
Ok(())
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct Cifs {
pub hostname: String,
pub path: PathBuf,
pub username: String,
pub password: Option<String>,
}
impl Cifs {
pub async fn mountable(&self) -> Result<(), Error> {
let guard = TmpMountGuard::mount(self, ReadOnly).await?;
guard.unmount().await?;
Ok(())
}
}
#[async_trait]
impl FileSystem for Cifs {
async fn mount<P: AsRef<std::path::Path> + Send + Sync>(
&self,
mountpoint: P,
mount_type: MountType,
) -> Result<(), Error> {
mount_cifs(
&self.hostname,
&self.path,
&self.username,
self.password.as_ref().map(|p| p.as_str()),
mountpoint,
mount_type,
)
.await
}
async fn source_hash(
&self,
) -> Result<GenericArray<u8, <Sha256 as OutputSizeUser>::OutputSize>, Error> {
let mut sha = Sha256::new();
sha.update("Cifs");
sha.update(self.hostname.as_bytes());
sha.update(self.path.as_os_str().as_bytes());
Ok(sha.finalize())
}
}

View File

@@ -0,0 +1,71 @@
use std::os::unix::ffi::OsStrExt;
use std::path::Path;
use async_trait::async_trait;
use digest::generic_array::GenericArray;
use digest::{Digest, OutputSizeUser};
use sha2::Sha256;
use super::{FileSystem, MountType};
use crate::util::Invoke;
use crate::{Error, ResultExt};
pub async fn mount_ecryptfs<P0: AsRef<Path>, P1: AsRef<Path>>(
src: P0,
dst: P1,
key: &str,
) -> Result<(), Error> {
tokio::fs::create_dir_all(dst.as_ref()).await?;
tokio::process::Command::new("mount")
.arg("-t")
.arg("ecryptfs")
.arg(src.as_ref())
.arg(dst.as_ref())
.arg("-o")
// for more information `man ecryptfs`
.arg(format!("key=passphrase:passphrase_passwd={},ecryptfs_cipher=aes,ecryptfs_key_bytes=32,ecryptfs_passthrough=n,ecryptfs_enable_filename_crypto=y,no_sig_cache", key))
.input(Some(&mut std::io::Cursor::new(b"\n")))
.invoke(crate::ErrorKind::Filesystem).await?;
Ok(())
}
pub struct EcryptFS<EncryptedDir: AsRef<Path>, Key: AsRef<str>> {
encrypted_dir: EncryptedDir,
key: Key,
}
impl<EncryptedDir: AsRef<Path>, Key: AsRef<str>> EcryptFS<EncryptedDir, Key> {
pub fn new(encrypted_dir: EncryptedDir, key: Key) -> Self {
EcryptFS { encrypted_dir, key }
}
}
#[async_trait]
impl<EncryptedDir: AsRef<Path> + Send + Sync, Key: AsRef<str> + Send + Sync> FileSystem
for EcryptFS<EncryptedDir, Key>
{
async fn mount<P: AsRef<Path> + Send + Sync>(
&self,
mountpoint: P,
_mount_type: MountType, // ignored - inherited from parent fs
) -> Result<(), Error> {
mount_ecryptfs(self.encrypted_dir.as_ref(), mountpoint, self.key.as_ref()).await
}
async fn source_hash(
&self,
) -> Result<GenericArray<u8, <Sha256 as OutputSizeUser>::OutputSize>, Error> {
let mut sha = Sha256::new();
sha.update("EcryptFS");
sha.update(
tokio::fs::canonicalize(self.encrypted_dir.as_ref())
.await
.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
self.encrypted_dir.as_ref().display().to_string(),
)
})?
.as_os_str()
.as_bytes(),
);
Ok(sha.finalize())
}
}

View File

@@ -0,0 +1,39 @@
use std::path::Path;
use async_trait::async_trait;
use digest::generic_array::GenericArray;
use digest::{Digest, OutputSizeUser};
use sha2::Sha256;
use super::{FileSystem, MountType, ReadOnly};
use crate::util::Invoke;
use crate::Error;
pub struct EfiVarFs;
#[async_trait]
impl FileSystem for EfiVarFs {
async fn mount<P: AsRef<Path> + Send + Sync>(
&self,
mountpoint: P,
mount_type: MountType,
) -> Result<(), Error> {
tokio::fs::create_dir_all(mountpoint.as_ref()).await?;
let mut cmd = tokio::process::Command::new("mount");
cmd.arg("-t")
.arg("efivarfs")
.arg("efivarfs")
.arg(mountpoint.as_ref());
if mount_type == ReadOnly {
cmd.arg("-o").arg("ro");
}
cmd.invoke(crate::ErrorKind::Filesystem).await?;
Ok(())
}
async fn source_hash(
&self,
) -> Result<GenericArray<u8, <Sha256 as OutputSizeUser>::OutputSize>, Error> {
let mut sha = Sha256::new();
sha.update("EfiVarFs");
Ok(sha.finalize())
}
}

View File

@@ -0,0 +1,52 @@
use std::path::Path;
use async_trait::async_trait;
use digest::generic_array::GenericArray;
use digest::{Digest, OutputSizeUser};
use reqwest::Url;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use super::{FileSystem, MountType};
use crate::util::Invoke;
use crate::Error;
pub async fn mount_httpdirfs(url: &Url, mountpoint: impl AsRef<Path>) -> Result<(), Error> {
tokio::fs::create_dir_all(mountpoint.as_ref()).await?;
let mut cmd = tokio::process::Command::new("httpdirfs");
cmd.arg("--cache")
.arg("--single-file-mode")
.arg(url.as_str())
.arg(mountpoint.as_ref());
cmd.invoke(crate::ErrorKind::Filesystem).await?;
Ok(())
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct HttpDirFS {
url: Url,
}
impl HttpDirFS {
pub fn new(url: Url) -> Self {
HttpDirFS { url }
}
}
#[async_trait]
impl FileSystem for HttpDirFS {
async fn mount<P: AsRef<Path> + Send + Sync>(
&self,
mountpoint: P,
_mount_type: MountType,
) -> Result<(), Error> {
mount_httpdirfs(&self.url, mountpoint).await
}
async fn source_hash(
&self,
) -> Result<GenericArray<u8, <Sha256 as OutputSizeUser>::OutputSize>, Error> {
let mut sha = Sha256::new();
sha.update("HttpDirFS");
sha.update(self.url.as_str());
Ok(sha.finalize())
}
}

View File

@@ -0,0 +1,52 @@
use std::path::Path;
use async_trait::async_trait;
use digest::generic_array::GenericArray;
use digest::{Digest, OutputSizeUser};
use sha2::Sha256;
use super::{FileSystem, MountType, ReadOnly};
use crate::util::Invoke;
use crate::Error;
pub async fn mount_label(
label: &str,
mountpoint: impl AsRef<Path>,
mount_type: MountType,
) -> Result<(), Error> {
tokio::fs::create_dir_all(mountpoint.as_ref()).await?;
let mut cmd = tokio::process::Command::new("mount");
cmd.arg("-L").arg(label).arg(mountpoint.as_ref());
if mount_type == ReadOnly {
cmd.arg("-o").arg("ro");
}
cmd.invoke(crate::ErrorKind::Filesystem).await?;
Ok(())
}
pub struct Label<S: AsRef<str>> {
label: S,
}
impl<S: AsRef<str>> Label<S> {
pub fn new(label: S) -> Self {
Label { label }
}
}
#[async_trait]
impl<S: AsRef<str> + Send + Sync> FileSystem for Label<S> {
async fn mount<P: AsRef<Path> + Send + Sync>(
&self,
mountpoint: P,
mount_type: MountType,
) -> Result<(), Error> {
mount_label(self.label.as_ref(), mountpoint, mount_type).await
}
async fn source_hash(
&self,
) -> Result<GenericArray<u8, <Sha256 as OutputSizeUser>::OutputSize>, Error> {
let mut sha = Sha256::new();
sha.update("Label");
sha.update(self.label.as_ref().as_bytes());
Ok(sha.finalize())
}
}

View File

@@ -0,0 +1,89 @@
use std::os::unix::ffi::OsStrExt;
use std::path::Path;
use async_trait::async_trait;
use digest::generic_array::GenericArray;
use digest::{Digest, OutputSizeUser};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use super::{FileSystem, MountType, ReadOnly};
use crate::util::Invoke;
use crate::{Error, ResultExt};
pub async fn mount(
logicalname: impl AsRef<Path>,
offset: u64,
size: u64,
mountpoint: impl AsRef<Path>,
mount_type: MountType,
) -> Result<(), Error> {
tokio::fs::create_dir_all(mountpoint.as_ref()).await?;
let mut opts = format!("loop,offset={offset},sizelimit={size}");
if mount_type == ReadOnly {
opts += ",ro";
}
tokio::process::Command::new("mount")
.arg(logicalname.as_ref())
.arg(mountpoint.as_ref())
.arg("-o")
.arg(opts)
.invoke(crate::ErrorKind::Filesystem)
.await?;
Ok(())
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct LoopDev<LogicalName: AsRef<Path>> {
logicalname: LogicalName,
offset: u64,
size: u64,
}
impl<LogicalName: AsRef<Path>> LoopDev<LogicalName> {
pub fn new(logicalname: LogicalName, offset: u64, size: u64) -> Self {
Self {
logicalname,
offset,
size,
}
}
}
#[async_trait]
impl<LogicalName: AsRef<Path> + Send + Sync> FileSystem for LoopDev<LogicalName> {
async fn mount<P: AsRef<Path> + Send + Sync>(
&self,
mountpoint: P,
mount_type: MountType,
) -> Result<(), Error> {
mount(
self.logicalname.as_ref(),
self.offset,
self.size,
mountpoint,
mount_type,
)
.await
}
async fn source_hash(
&self,
) -> Result<GenericArray<u8, <Sha256 as OutputSizeUser>::OutputSize>, Error> {
let mut sha = Sha256::new();
sha.update("LoopDev");
sha.update(
tokio::fs::canonicalize(self.logicalname.as_ref())
.await
.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
self.logicalname.as_ref().display().to_string(),
)
})?
.as_os_str()
.as_bytes(),
);
sha.update(&u64::to_be_bytes(self.offset)[..]);
Ok(sha.finalize())
}
}

View File

@@ -0,0 +1,37 @@
use std::path::Path;
use async_trait::async_trait;
use digest::generic_array::GenericArray;
use digest::OutputSizeUser;
use sha2::Sha256;
use crate::Error;
pub mod bind;
pub mod block_dev;
pub mod cifs;
pub mod ecryptfs;
pub mod efivarfs;
pub mod httpdirfs;
pub mod label;
pub mod loop_dev;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MountType {
ReadOnly,
ReadWrite,
}
pub use MountType::*;
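/// A mountable source. `source_hash` should be stable for a given source, since it is used to
/// derive the deterministic temporary mountpoints shared by `TmpMountGuard`.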
#[async_trait]
pub trait FileSystem {
async fn mount<P: AsRef<Path> + Send + Sync>(
&self,
mountpoint: P,
mount_type: MountType,
) -> Result<(), Error>;
async fn source_hash(
&self,
) -> Result<GenericArray<u8, <Sha256 as OutputSizeUser>::OutputSize>, Error>;
}

View File

@@ -0,0 +1,142 @@
use std::collections::BTreeMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Weak};
use lazy_static::lazy_static;
use models::ResultExt;
use tokio::sync::Mutex;
use tracing::instrument;
use super::filesystem::{FileSystem, MountType, ReadOnly, ReadWrite};
use super::util::unmount;
use crate::util::Invoke;
use crate::Error;
pub const TMP_MOUNTPOINT: &'static str = "/media/embassy/tmp";
#[async_trait::async_trait]
pub trait GenericMountGuard: AsRef<Path> + std::fmt::Debug + Send + Sync + 'static {
async fn unmount(mut self) -> Result<(), Error>;
}
#[derive(Debug)]
pub struct MountGuard {
mountpoint: PathBuf,
mounted: bool,
}
impl MountGuard {
pub async fn mount(
filesystem: &impl FileSystem,
mountpoint: impl AsRef<Path>,
mount_type: MountType,
) -> Result<Self, Error> {
let mountpoint = mountpoint.as_ref().to_owned();
filesystem.mount(&mountpoint, mount_type).await?;
Ok(MountGuard {
mountpoint,
mounted: true,
})
}
pub async fn unmount(mut self, delete_mountpoint: bool) -> Result<(), Error> {
if self.mounted {
unmount(&self.mountpoint).await?;
if delete_mountpoint {
match tokio::fs::remove_dir(&self.mountpoint).await {
Err(e) if e.raw_os_error() == Some(39) => Ok(()), // directory not empty
a => a,
}
.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
format!("rm {}", self.mountpoint.display()),
)
})?;
}
self.mounted = false;
}
Ok(())
}
}
impl AsRef<Path> for MountGuard {
fn as_ref(&self) -> &Path {
&self.mountpoint
}
}
impl Drop for MountGuard {
fn drop(&mut self) {
if self.mounted {
let mountpoint = std::mem::take(&mut self.mountpoint);
tokio::spawn(async move { unmount(mountpoint).await.unwrap() });
}
}
}
#[async_trait::async_trait]
impl GenericMountGuard for MountGuard {
async fn unmount(mut self) -> Result<(), Error> {
MountGuard::unmount(self, false).await
}
}
async fn tmp_mountpoint(source: &impl FileSystem) -> Result<PathBuf, Error> {
Ok(Path::new(TMP_MOUNTPOINT).join(base32::encode(
base32::Alphabet::RFC4648 { padding: false },
&source.source_hash().await?,
)))
}
lazy_static! {
static ref TMP_MOUNTS: Mutex<BTreeMap<PathBuf, (MountType, Weak<MountGuard>)>> =
Mutex::new(BTreeMap::new());
}
#[derive(Debug)]
pub struct TmpMountGuard {
guard: Arc<MountGuard>,
}
impl TmpMountGuard {
/// DRAGONS: if you try to mount something as ro and rw at the same time, the ro mount will be upgraded to rw.
#[instrument(skip_all)]
pub async fn mount(filesystem: &impl FileSystem, mount_type: MountType) -> Result<Self, Error> {
let mountpoint = tmp_mountpoint(filesystem).await?;
let mut tmp_mounts = TMP_MOUNTS.lock().await;
if !tmp_mounts.contains_key(&mountpoint) {
tmp_mounts.insert(mountpoint.clone(), (mount_type, Weak::new()));
}
let (prev_mt, weak_slot) = tmp_mounts.get_mut(&mountpoint).unwrap();
if let Some(guard) = weak_slot.upgrade() {
// upgrade to rw
if *prev_mt == ReadOnly && mount_type == ReadWrite {
tokio::process::Command::new("mount")
.arg("-o")
.arg("remount,rw")
.arg(&mountpoint)
.invoke(crate::ErrorKind::Filesystem)
.await?;
*prev_mt = ReadWrite;
}
Ok(TmpMountGuard { guard })
} else {
let guard = Arc::new(MountGuard::mount(filesystem, &mountpoint, mount_type).await?);
*weak_slot = Arc::downgrade(&guard);
*prev_mt = mount_type;
Ok(TmpMountGuard { guard })
}
}
pub async fn unmount(self) -> Result<(), Error> {
if let Ok(guard) = Arc::try_unwrap(self.guard) {
guard.unmount(true).await?;
}
Ok(())
}
}
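// A minimal usage sketch (hypothetical label value, using the `Label` filesystem from this
// commit): concurrent mounts of the same source share a single guard, a ReadWrite request
// upgrades an existing ReadOnly mount in place, and only the last guard to unmount actually
// tears the mount down.
//
//     let ro = TmpMountGuard::mount(&Label::new("EMBASSY"), ReadOnly).await?;
//     let rw = TmpMountGuard::mount(&Label::new("EMBASSY"), ReadWrite).await?; // remounted rw
//     rw.unmount().await?; // no-op: `ro` still holds a reference to the shared guard
//     ro.unmount().await?; // unmounts and cleans up the temporary mountpoint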
impl AsRef<Path> for TmpMountGuard {
fn as_ref(&self) -> &Path {
(&*self.guard).as_ref()
}
}
#[async_trait::async_trait]
impl GenericMountGuard for TmpMountGuard {
async fn unmount(mut self) -> Result<(), Error> {
TmpMountGuard::unmount(self).await
}
}

View File

@@ -0,0 +1,4 @@
pub mod backup;
pub mod filesystem;
pub mod guard;
pub mod util;

View File

@@ -0,0 +1,52 @@
use std::path::Path;
use tracing::instrument;
use crate::util::Invoke;
use crate::Error;
#[instrument(skip_all)]
pub async fn bind<P0: AsRef<Path>, P1: AsRef<Path>>(
src: P0,
dst: P1,
read_only: bool,
) -> Result<(), Error> {
tracing::info!(
"Binding {} to {}",
src.as_ref().display(),
dst.as_ref().display()
);
let is_mountpoint = tokio::process::Command::new("mountpoint")
.arg(dst.as_ref())
.stdout(std::process::Stdio::null())
.stderr(std::process::Stdio::null())
.status()
.await?;
if is_mountpoint.success() {
unmount(dst.as_ref()).await?;
}
tokio::fs::create_dir_all(&src).await?;
tokio::fs::create_dir_all(&dst).await?;
let mut mount_cmd = tokio::process::Command::new("mount");
mount_cmd.arg("--bind");
if read_only {
mount_cmd.arg("-o").arg("ro");
}
mount_cmd
.arg(src.as_ref())
.arg(dst.as_ref())
.invoke(crate::ErrorKind::Filesystem)
.await?;
Ok(())
}
#[instrument(skip_all)]
pub async fn unmount<P: AsRef<Path>>(mountpoint: P) -> Result<(), Error> {
tracing::debug!("Unmounting {}.", mountpoint.as_ref().display());
tokio::process::Command::new("umount")
.arg("-l")
.arg(mountpoint.as_ref())
.invoke(crate::ErrorKind::Filesystem)
.await?;
Ok(())
}

View File

@@ -0,0 +1,489 @@
use std::collections::{BTreeMap, BTreeSet};
use std::path::{Path, PathBuf};
use color_eyre::eyre::{self, eyre};
use futures::TryStreamExt;
use nom::bytes::complete::{tag, take_till1};
use nom::character::complete::multispace1;
use nom::character::is_space;
use nom::combinator::{opt, rest};
use nom::sequence::{pair, preceded, terminated};
use nom::IResult;
use regex::Regex;
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use tracing::instrument;
use super::mount::filesystem::block_dev::BlockDev;
use super::mount::filesystem::ReadOnly;
use super::mount::guard::TmpMountGuard;
use crate::disk::OsPartitionInfo;
use crate::util::serde::IoFormat;
use crate::util::{Invoke, Version};
use crate::{Error, ResultExt as _};
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum PartitionTable {
Mbr,
Gpt,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct DiskInfo {
pub logicalname: PathBuf,
pub partition_table: Option<PartitionTable>,
pub vendor: Option<String>,
pub model: Option<String>,
pub partitions: Vec<PartitionInfo>,
pub capacity: u64,
pub guid: Option<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct PartitionInfo {
pub logicalname: PathBuf,
pub label: Option<String>,
pub capacity: u64,
pub used: Option<u64>,
pub embassy_os: Option<EmbassyOsRecoveryInfo>,
pub guid: Option<String>,
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct EmbassyOsRecoveryInfo {
pub version: Version,
pub full: bool,
pub password_hash: Option<String>,
pub wrapped_key: Option<String>,
}
const DISK_PATH: &str = "/dev/disk/by-path";
const SYS_BLOCK_PATH: &str = "/sys/block";
lazy_static::lazy_static! {
static ref PARTITION_REGEX: Regex = Regex::new("-part[0-9]+$").unwrap();
}
#[instrument(skip_all)]
pub async fn get_partition_table<P: AsRef<Path>>(path: P) -> Result<Option<PartitionTable>, Error> {
Ok(String::from_utf8(
Command::new("fdisk")
.arg("-l")
.arg(path.as_ref())
.invoke(crate::ErrorKind::BlockDevice)
.await?,
)?
.lines()
.find_map(|l| l.strip_prefix("Disklabel type:"))
.and_then(|t| match t.trim() {
"dos" => Some(PartitionTable::Mbr),
"gpt" => Some(PartitionTable::Gpt),
_ => None,
}))
}
#[instrument(skip_all)]
pub async fn get_vendor<P: AsRef<Path>>(path: P) -> Result<Option<String>, Error> {
let vendor = tokio::fs::read_to_string(
Path::new(SYS_BLOCK_PATH)
.join(path.as_ref().strip_prefix("/dev").map_err(|_| {
Error::new(
eyre!("not a canonical block device"),
crate::ErrorKind::BlockDevice,
)
})?)
.join("device")
.join("vendor"),
)
.await?
.trim()
.to_owned();
Ok(if vendor.is_empty() {
None
} else {
Some(vendor)
})
}
#[instrument(skip_all)]
pub async fn get_model<P: AsRef<Path>>(path: P) -> Result<Option<String>, Error> {
let model = tokio::fs::read_to_string(
Path::new(SYS_BLOCK_PATH)
.join(path.as_ref().strip_prefix("/dev").map_err(|_| {
Error::new(
eyre!("not a canonical block device"),
crate::ErrorKind::BlockDevice,
)
})?)
.join("device")
.join("model"),
)
.await?
.trim()
.to_owned();
Ok(if model.is_empty() { None } else { Some(model) })
}
#[instrument(skip_all)]
pub async fn get_capacity<P: AsRef<Path>>(path: P) -> Result<u64, Error> {
Ok(String::from_utf8(
Command::new("blockdev")
.arg("--getsize64")
.arg(path.as_ref())
.invoke(crate::ErrorKind::BlockDevice)
.await?,
)?
.trim()
.parse::<u64>()?)
}
#[instrument(skip_all)]
pub async fn get_label<P: AsRef<Path>>(path: P) -> Result<Option<String>, Error> {
let label = String::from_utf8(
Command::new("lsblk")
.arg("-no")
.arg("label")
.arg(path.as_ref())
.invoke(crate::ErrorKind::BlockDevice)
.await?,
)?
.trim()
.to_owned();
Ok(if label.is_empty() { None } else { Some(label) })
}
#[instrument(skip_all)]
pub async fn get_used<P: AsRef<Path>>(path: P) -> Result<u64, Error> {
Ok(String::from_utf8(
Command::new("df")
.arg("--output=used")
.arg("--block-size=1")
.arg(path.as_ref())
.invoke(crate::ErrorKind::Filesystem)
.await?,
)?
.lines()
.skip(1)
.next()
.unwrap_or_default()
.trim()
.parse::<u64>()?)
}
#[instrument(skip_all)]
pub async fn get_available<P: AsRef<Path>>(path: P) -> Result<u64, Error> {
Ok(String::from_utf8(
Command::new("df")
.arg("--output=avail")
.arg("--block-size=1")
.arg(path.as_ref())
.invoke(crate::ErrorKind::Filesystem)
.await?,
)?
.lines()
.skip(1)
.next()
.unwrap_or_default()
.trim()
.parse::<u64>()?)
}
#[instrument(skip_all)]
pub async fn get_percentage<P: AsRef<Path>>(path: P) -> Result<u64, Error> {
Ok(String::from_utf8(
Command::new("df")
.arg("--output=pcent")
.arg(path.as_ref())
.invoke(crate::ErrorKind::Filesystem)
.await?,
)?
.lines()
.skip(1)
.next()
.unwrap_or_default()
.trim()
.strip_suffix("%")
.unwrap()
.parse::<u64>()?)
}
#[instrument(skip_all)]
pub async fn pvscan() -> Result<BTreeMap<PathBuf, Option<String>>, Error> {
let pvscan_out = Command::new("pvscan")
.invoke(crate::ErrorKind::DiskManagement)
.await?;
let pvscan_out_str = std::str::from_utf8(&pvscan_out)?;
Ok(parse_pvscan_output(pvscan_out_str))
}
pub async fn recovery_info(
mountpoint: impl AsRef<Path>,
) -> Result<Option<EmbassyOsRecoveryInfo>, Error> {
let backup_unencrypted_metadata_path = mountpoint
.as_ref()
.join("EmbassyBackups/unencrypted-metadata.cbor");
if tokio::fs::metadata(&backup_unencrypted_metadata_path)
.await
.is_ok()
{
return Ok(Some(
IoFormat::Cbor.from_slice(
&tokio::fs::read(&backup_unencrypted_metadata_path)
.await
.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
backup_unencrypted_metadata_path.display().to_string(),
)
})?,
)?,
));
}
Ok(None)
}
#[instrument(skip_all)]
pub async fn list(os: &OsPartitionInfo) -> Result<Vec<DiskInfo>, Error> {
struct DiskIndex {
parts: BTreeSet<PathBuf>,
internal: bool,
}
let disk_guids = pvscan().await?;
let disks = tokio_stream::wrappers::ReadDirStream::new(
tokio::fs::read_dir(DISK_PATH)
.await
.with_ctx(|_| (crate::ErrorKind::Filesystem, DISK_PATH))?,
)
.map_err(|e| {
Error::new(
eyre::Error::from(e).wrap_err(DISK_PATH),
crate::ErrorKind::Filesystem,
)
})
.try_fold(
BTreeMap::<PathBuf, DiskIndex>::new(),
|mut disks, dir_entry| async move {
if let Some(disk_path) = dir_entry.path().file_name().and_then(|s| s.to_str()) {
let (disk_path, part_path) = if let Some(end) = PARTITION_REGEX.find(disk_path) {
(
disk_path.strip_suffix(end.as_str()).unwrap_or_default(),
Some(disk_path),
)
} else {
(disk_path, None)
};
let disk_path = Path::new(DISK_PATH).join(disk_path);
let disk = tokio::fs::canonicalize(&disk_path).await.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
disk_path.display().to_string(),
)
})?;
let part = if let Some(part_path) = part_path {
let part_path = Path::new(DISK_PATH).join(part_path);
let part = tokio::fs::canonicalize(&part_path).await.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
part_path.display().to_string(),
)
})?;
Some(part)
} else {
None
};
if !disks.contains_key(&disk) {
disks.insert(
disk.clone(),
DiskIndex {
parts: BTreeSet::new(),
internal: false,
},
);
}
if let Some(part) = part {
if os.contains(&part) {
disks.get_mut(&disk).unwrap().internal = true;
} else {
disks.get_mut(&disk).unwrap().parts.insert(part);
}
}
}
Ok(disks)
},
)
.await?;
let mut res = Vec::with_capacity(disks.len());
for (disk, index) in disks {
if index.internal {
for part in index.parts {
let mut disk_info = disk_info(disk.clone()).await;
let part_info = part_info(part).await;
disk_info.logicalname = part_info.logicalname.clone();
disk_info.capacity = part_info.capacity;
if let Some(g) = disk_guids.get(&disk_info.logicalname) {
disk_info.guid = g.clone();
} else {
disk_info.partitions = vec![part_info];
}
res.push(disk_info);
}
} else {
let mut disk_info = disk_info(disk).await;
disk_info.partitions = Vec::with_capacity(index.parts.len());
if let Some(g) = disk_guids.get(&disk_info.logicalname) {
disk_info.guid = g.clone();
} else {
for part in index.parts {
let mut part_info = part_info(part).await;
if let Some(g) = disk_guids.get(&part_info.logicalname) {
part_info.guid = g.clone();
}
disk_info.partitions.push(part_info);
}
}
res.push(disk_info);
}
}
Ok(res)
}
async fn disk_info(disk: PathBuf) -> DiskInfo {
let partition_table = get_partition_table(&disk)
.await
.map_err(|e| {
tracing::warn!(
"Could not get partition table of {}: {}",
disk.display(),
e.source
)
})
.unwrap_or_default();
let vendor = get_vendor(&disk)
.await
.map_err(|e| tracing::warn!("Could not get vendor of {}: {}", disk.display(), e.source))
.unwrap_or_default();
let model = get_model(&disk)
.await
.map_err(|e| tracing::warn!("Could not get model of {}: {}", disk.display(), e.source))
.unwrap_or_default();
let capacity = get_capacity(&disk)
.await
.map_err(|e| tracing::warn!("Could not get capacity of {}: {}", disk.display(), e.source))
.unwrap_or_default();
DiskInfo {
logicalname: disk,
partition_table,
vendor,
model,
partitions: Vec::new(),
capacity,
guid: None,
}
}
async fn part_info(part: PathBuf) -> PartitionInfo {
let mut embassy_os = None;
let label = get_label(&part)
.await
.map_err(|e| tracing::warn!("Could not get label of {}: {}", part.display(), e.source))
.unwrap_or_default();
let capacity = get_capacity(&part)
.await
.map_err(|e| tracing::warn!("Could not get capacity of {}: {}", part.display(), e.source))
.unwrap_or_default();
let mut used = None;
match TmpMountGuard::mount(&BlockDev::new(&part), ReadOnly).await {
Err(e) => tracing::warn!("Could not collect usage information: {}", e.source),
Ok(mount_guard) => {
used = get_used(&mount_guard)
.await
.map_err(|e| {
tracing::warn!("Could not get usage of {}: {}", part.display(), e.source)
})
.ok();
if let Some(recovery_info) = match recovery_info(&mount_guard).await {
Ok(a) => a,
Err(e) => {
tracing::error!("Error fetching unencrypted backup metadata: {}", e);
None
}
} {
embassy_os = Some(recovery_info)
}
if let Err(e) = mount_guard.unmount().await {
tracing::error!("Error unmounting partition {}: {}", part.display(), e);
}
}
}
PartitionInfo {
logicalname: part,
label,
capacity,
used,
embassy_os,
guid: None,
}
}
fn parse_pvscan_output(pvscan_output: &str) -> BTreeMap<PathBuf, Option<String>> {
fn parse_line(line: &str) -> IResult<&str, (&str, Option<&str>)> {
let pv_parse = preceded(
tag(" PV "),
terminated(take_till1(|c| is_space(c as u8)), multispace1),
);
let vg_parse = preceded(
opt(tag("is in exported ")),
preceded(
tag("VG "),
terminated(take_till1(|c| is_space(c as u8)), multispace1),
),
);
let mut parser = terminated(pair(pv_parse, opt(vg_parse)), rest);
parser(line)
}
let lines = pvscan_output.lines();
let n = lines.clone().count();
let entries = lines.take(n.saturating_sub(1));
let mut ret = BTreeMap::new();
for entry in entries {
match parse_line(entry) {
Ok((_, (pv, vg))) => {
ret.insert(PathBuf::from(pv), vg.map(|s| s.to_owned()));
}
Err(_) => {
tracing::warn!("Failed to parse pvscan output line: {}", entry);
}
}
}
ret
}
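// For reference, the mapping produced for `s1` in the test below (assuming the
// whitespace prefixes match `parse_line`):
//   /dev/mapper/cryptdata => Some("data")   (PV assigned to VG "data")
//   /dev/sdb              => None           (PV not in any VG)
// The trailing "Total:" summary line is dropped by taking all but the last line above.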
#[test]
fn test_pvscan_parser() {
let s1 = r#" PV /dev/mapper/cryptdata VG data lvm2 [1.81 TiB / 0 free]
PV /dev/sdb lvm2 [931.51 GiB]
Total: 2 [2.72 TiB] / in use: 1 [1.81 TiB] / in no VG: 1 [931.51 GiB]
"#;
let s2 = r#" PV /dev/sdb VG EMBASSY_LZHJAENWGPCJJL6C6AXOD7OOOIJG7HFBV4GYRJH6HADXUCN4BRWQ lvm2 [931.51 GiB / 0 free]
Total: 1 [931.51 GiB] / in use: 1 [931.51 GiB] / in no VG: 0 [0 ]
"#;
let s3 = r#" PV /dev/mapper/cryptdata VG data lvm2 [1.81 TiB / 0 free]
Total: 1 [1.81 TiB] / in use: 1 [1.81 TiB] / in no VG: 0 [0 ]
"#;
let s4 = r#" PV /dev/sda is in exported VG EMBASSY_ZFHOCTYV3ZJMJW3OTFMG55LSQZLP667EDNZKDNUJKPJX5HE6S5HQ [931.51 GiB / 0 free]
Total: 1 [931.51 GiB] / in use: 1 [931.51 GiB] / in no VG: 0 [0 ]
"#;
println!("{:?}", parse_pvscan_output(s1));
println!("{:?}", parse_pvscan_output(s2));
println!("{:?}", parse_pvscan_output(s3));
println!("{:?}", parse_pvscan_output(s4));
}

60
core/startos/src/error.rs Normal file
View File

@@ -0,0 +1,60 @@
use color_eyre::eyre::eyre;
pub use models::{Error, ErrorKind, OptionExt, ResultExt};
#[derive(Debug, Default)]
pub struct ErrorCollection(Vec<Error>);
impl ErrorCollection {
pub fn new() -> Self {
Self::default()
}
pub fn handle<T, E: Into<Error>>(&mut self, result: Result<T, E>) -> Option<T> {
match result {
Ok(a) => Some(a),
Err(e) => {
self.0.push(e.into());
None
}
}
}
pub fn into_result(self) -> Result<(), Error> {
if self.0.is_empty() {
Ok(())
} else {
Err(Error::new(eyre!("{}", self), ErrorKind::MultipleErrors))
}
}
}
impl From<ErrorCollection> for Result<(), Error> {
fn from(e: ErrorCollection) -> Self {
e.into_result()
}
}
impl<T, E: Into<Error>> Extend<Result<T, E>> for ErrorCollection {
fn extend<I: IntoIterator<Item = Result<T, E>>>(&mut self, iter: I) {
for item in iter {
self.handle(item);
}
}
}
impl std::fmt::Display for ErrorCollection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for (idx, e) in self.0.iter().enumerate() {
if idx > 0 {
write!(f, "; ")?;
}
write!(f, "{}", e)?;
}
Ok(())
}
}
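// A minimal usage sketch (illustrative; `items` and `process` are placeholders):
// collect failures from a batch of fallible steps and surface them as one error.
//
//     let mut errors = ErrorCollection::new();
//     for item in items {
//         errors.handle(process(item)); // each Err is recorded, each Ok passes through
//     }
//     errors.into_result()?; // ErrorKind::MultipleErrors if anything failed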
#[macro_export]
macro_rules! ensure_code {
($x:expr, $c:expr, $fmt:expr $(, $arg:expr)*) => {
if !($x) {
return Err(crate::error::Error::new(color_eyre::eyre::eyre!($fmt, $($arg, )*), $c));
}
};
}
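// A minimal usage sketch (illustrative): bail out with a typed error when a
// precondition does not hold.
//
//     ensure_code!(
//         !input.is_empty(),
//         crate::ErrorKind::InvalidRequest,
//         "input must not be empty"
//     );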

View File

@@ -0,0 +1,70 @@
use std::path::Path;
use async_compression::tokio::bufread::GzipDecoder;
use tokio::fs::File;
use tokio::io::{AsyncRead, BufReader};
use tokio::process::Command;
use crate::disk::fsck::RequiresReboot;
use crate::prelude::*;
use crate::util::Invoke;
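/// Flashes a vendor-provided firmware image for this product, if one is staged.
///
/// Flow (from the body below): the `dmidecode` system-product-name selects a directory
/// under /usr/lib/startos/firmware; if that directory has no rom matching the current
/// bios-version, the first `.rom`/`.rom.gz` found there is streamed into
/// `flashrom -p internal -w-` and a reboot is requested.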
pub async fn update_firmware() -> Result<RequiresReboot, Error> {
let product_name = String::from_utf8(
Command::new("dmidecode")
.arg("-s")
.arg("system-product-name")
.invoke(ErrorKind::Firmware)
.await?,
)?
.trim()
.to_owned();
if product_name.is_empty() {
return Ok(RequiresReboot(false));
}
let firmware_dir = Path::new("/usr/lib/startos/firmware").join(&product_name);
if tokio::fs::metadata(&firmware_dir).await.is_ok() {
let current_firmware = String::from_utf8(
Command::new("dmidecode")
.arg("-s")
.arg("bios-version")
.invoke(ErrorKind::Firmware)
.await?,
)?
.trim()
.to_owned();
if tokio::fs::metadata(firmware_dir.join(format!("{current_firmware}.rom.gz")))
.await
.is_err()
&& tokio::fs::metadata(firmware_dir.join(format!("{current_firmware}.rom")))
.await
.is_err()
{
let mut firmware_read_dir = tokio::fs::read_dir(&firmware_dir).await?;
while let Some(entry) = firmware_read_dir.next_entry().await? {
let filename = entry.file_name().to_string_lossy().into_owned();
let rdr: Option<Box<dyn AsyncRead + Unpin + Send>> =
if filename.ends_with(".rom.gz") {
Some(Box::new(GzipDecoder::new(BufReader::new(
File::open(entry.path()).await?,
))))
} else if filename.ends_with(".rom") {
Some(Box::new(File::open(entry.path()).await?))
} else {
None
};
if let Some(mut rdr) = rdr {
Command::new("flashrom")
.arg("-p")
.arg("internal")
.arg("-w-")
.input(Some(&mut rdr))
.invoke(ErrorKind::Firmware)
.await?;
return Ok(RequiresReboot(true));
}
}
}
}
Ok(RequiresReboot(false))
}

View File

@@ -0,0 +1,83 @@
use rand::{thread_rng, Rng};
use tokio::process::Command;
use tracing::instrument;
use crate::util::Invoke;
use crate::{Error, ErrorKind};
#[derive(Clone, serde::Deserialize, serde::Serialize, Debug)]
pub struct Hostname(pub String);
lazy_static::lazy_static! {
static ref ADJECTIVES: Vec<String> = include_str!("./assets/adjectives.txt").lines().map(|x| x.to_string()).collect();
static ref NOUNS: Vec<String> = include_str!("./assets/nouns.txt").lines().map(|x| x.to_string()).collect();
}
impl AsRef<str> for Hostname {
fn as_ref(&self) -> &str {
&self.0
}
}
impl Hostname {
pub fn lan_address(&self) -> String {
format!("https://{}.local", self.0)
}
pub fn local_domain_name(&self) -> String {
format!("{}.local", self.0)
}
pub fn no_dot_host_name(&self) -> String {
self.0.to_owned()
}
}
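// e.g. for Hostname("brave-otter".to_string()) (an illustrative value), lan_address()
// is "https://brave-otter.local" and local_domain_name() is "brave-otter.local".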
pub fn generate_hostname() -> Hostname {
let mut rng = thread_rng();
let adjective = &ADJECTIVES[rng.gen_range(0..ADJECTIVES.len())];
let noun = &NOUNS[rng.gen_range(0..NOUNS.len())];
Hostname(format!("{adjective}-{noun}"))
}
pub fn generate_id() -> String {
let id = uuid::Uuid::new_v4();
id.to_string()
}
#[instrument(skip_all)]
pub async fn get_current_hostname() -> Result<Hostname, Error> {
let out = Command::new("hostname")
.invoke(ErrorKind::ParseSysInfo)
.await?;
let out_string = String::from_utf8(out)?;
Ok(Hostname(out_string.trim().to_owned()))
}
#[instrument(skip_all)]
pub async fn set_hostname(hostname: &Hostname) -> Result<(), Error> {
let hostname: &String = &hostname.0;
Command::new("hostnamectl")
.arg("--static")
.arg("set-hostname")
.arg(hostname)
.invoke(ErrorKind::ParseSysInfo)
.await?;
Command::new("sed")
.arg("-i")
.arg(format!(
"s/\\(\\s\\)localhost\\( {hostname}\\)\\?/\\1localhost {hostname}/g"
))
.arg("/etc/hosts")
.invoke(ErrorKind::ParseSysInfo)
.await?;
Ok(())
}
#[instrument(skip_all)]
pub async fn sync_hostname(hostname: &Hostname) -> Result<(), Error> {
set_hostname(hostname).await?;
Command::new("systemctl")
.arg("restart")
.arg("avahi-daemon")
.invoke(crate::ErrorKind::Network)
.await?;
Ok(())
}

455
core/startos/src/init.rs Normal file
View File

@@ -0,0 +1,455 @@
use std::fs::Permissions;
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::time::{Duration, SystemTime};
use color_eyre::eyre::eyre;
use helpers::NonDetachingJoinHandle;
use models::ResultExt;
use rand::random;
use sqlx::{Pool, Postgres};
use tokio::process::Command;
use tracing::instrument;
use crate::account::AccountInfo;
use crate::context::rpc::RpcContextConfig;
use crate::db::model::ServerStatus;
use crate::disk::mount::util::unmount;
use crate::install::PKG_ARCHIVE_DIR;
use crate::middleware::auth::LOCAL_AUTH_COOKIE_PATH;
use crate::prelude::*;
use crate::sound::BEP;
use crate::util::cpupower::{
current_governor, get_available_governors, set_governor, GOVERNOR_PERFORMANCE,
};
use crate::util::docker::{create_bridge_network, CONTAINER_DATADIR, CONTAINER_TOOL};
use crate::util::Invoke;
use crate::{Error, ARCH};
pub const SYSTEM_REBUILD_PATH: &str = "/media/embassy/config/system-rebuild";
pub const STANDBY_MODE_PATH: &str = "/media/embassy/config/standby";
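// `timedatectl show -p NTPSynchronized` prints a single "NTPSynchronized=yes|no" line;
// anything other than "yes" is treated as not synchronized.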
pub async fn check_time_is_synchronized() -> Result<bool, Error> {
Ok(String::from_utf8(
Command::new("timedatectl")
.arg("show")
.arg("-p")
.arg("NTPSynchronized")
.invoke(crate::ErrorKind::Unknown)
.await?,
)?
.trim()
== "NTPSynchronized=yes")
}
// must be idempotent
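// Sketch of the flow: bind-mount <datadir>/main/postgresql over /var/lib/postgresql
// (copying the stock data dir on first run), pick the newest postgres major version
// installed under /usr/lib/postgresql, run pg_upgradecluster if an older cluster is
// present, start postgresql@<version>-main, and on first run create the root role and
// the "secrets" database.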
#[tracing::instrument(skip_all)]
pub async fn init_postgres(datadir: impl AsRef<Path>) -> Result<(), Error> {
let db_dir = datadir.as_ref().join("main/postgresql");
if tokio::process::Command::new("mountpoint")
.arg("/var/lib/postgresql")
.stdout(std::process::Stdio::null())
.stderr(std::process::Stdio::null())
.status()
.await?
.success()
{
unmount("/var/lib/postgresql").await?;
}
let exists = tokio::fs::metadata(&db_dir).await.is_ok();
if !exists {
Command::new("cp")
.arg("-ra")
.arg("/var/lib/postgresql")
.arg(&db_dir)
.invoke(crate::ErrorKind::Filesystem)
.await?;
}
Command::new("chown")
.arg("-R")
.arg("postgres:postgres")
.arg(&db_dir)
.invoke(crate::ErrorKind::Database)
.await?;
let mut pg_paths = tokio::fs::read_dir("/usr/lib/postgresql").await?;
let mut pg_version = None;
while let Some(pg_path) = pg_paths.next_entry().await? {
let pg_path_version = pg_path
.file_name()
.to_str()
.map(|v| v.parse())
.transpose()?
.unwrap_or(0);
if pg_path_version > pg_version.unwrap_or(0) {
pg_version = Some(pg_path_version)
}
}
let pg_version = pg_version.ok_or_else(|| {
Error::new(
eyre!("could not determine postgresql version"),
crate::ErrorKind::Database,
)
})?;
crate::disk::mount::util::bind(&db_dir, "/var/lib/postgresql", false).await?;
let pg_version_string = pg_version.to_string();
let pg_version_path = db_dir.join(&pg_version_string);
if tokio::fs::metadata(&pg_version_path).await.is_err() {
let conf_dir = Path::new("/etc/postgresql").join(pg_version.to_string());
let conf_dir_tmp = {
let mut tmp = conf_dir.clone();
tmp.set_extension("tmp");
tmp
};
if tokio::fs::metadata(&conf_dir).await.is_ok() {
Command::new("mv")
.arg(&conf_dir)
.arg(&conf_dir_tmp)
.invoke(ErrorKind::Filesystem)
.await?;
}
let mut old_version = pg_version;
while old_version > 13
/* oldest pg version included in startos */
{
old_version -= 1;
let old_datadir = db_dir.join(old_version.to_string());
if tokio::fs::metadata(&old_datadir).await.is_ok() {
Command::new("pg_upgradecluster")
.arg(old_version.to_string())
.arg("main")
.invoke(crate::ErrorKind::Database)
.await?;
break;
}
}
if tokio::fs::metadata(&conf_dir).await.is_ok() {
if tokio::fs::metadata(&conf_dir).await.is_ok() {
tokio::fs::remove_dir_all(&conf_dir).await?;
}
Command::new("mv")
.arg(&conf_dir_tmp)
.arg(&conf_dir)
.invoke(ErrorKind::Filesystem)
.await?;
}
}
Command::new("systemctl")
.arg("start")
.arg(format!("postgresql@{pg_version}-main.service"))
.invoke(crate::ErrorKind::Database)
.await?;
if !exists {
Command::new("sudo")
.arg("-u")
.arg("postgres")
.arg("createuser")
.arg("root")
.invoke(crate::ErrorKind::Database)
.await?;
Command::new("sudo")
.arg("-u")
.arg("postgres")
.arg("createdb")
.arg("secrets")
.arg("-O")
.arg("root")
.invoke(crate::ErrorKind::Database)
.await?;
}
Ok(())
}
pub struct InitResult {
pub secret_store: Pool<Postgres>,
pub db: patch_db::PatchDb,
}
#[instrument(skip_all)]
pub async fn init(cfg: &RpcContextConfig) -> Result<InitResult, Error> {
tokio::fs::create_dir_all("/run/embassy")
.await
.with_ctx(|_| (crate::ErrorKind::Filesystem, "mkdir -p /run/embassy"))?;
if tokio::fs::metadata(LOCAL_AUTH_COOKIE_PATH).await.is_err() {
tokio::fs::write(
LOCAL_AUTH_COOKIE_PATH,
base64::encode(random::<[u8; 32]>()).as_bytes(),
)
.await
.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
format!("write {}", LOCAL_AUTH_COOKIE_PATH),
)
})?;
tokio::fs::set_permissions(LOCAL_AUTH_COOKIE_PATH, Permissions::from_mode(0o046)).await?;
Command::new("chown")
.arg("root:embassy")
.arg(LOCAL_AUTH_COOKIE_PATH)
.invoke(crate::ErrorKind::Filesystem)
.await?;
}
let secret_store = cfg.secret_store().await?;
tracing::info!("Opened Postgres");
crate::ssh::sync_keys_from_db(&secret_store, "/home/start9/.ssh/authorized_keys").await?;
tracing::info!("Synced SSH Keys");
let account = AccountInfo::load(&secret_store).await?;
let db = cfg.db(&account).await?;
tracing::info!("Opened PatchDB");
let peek = db.peek().await;
let mut server_info = peek.as_server_info().de()?;
// write to ca cert store
tokio::fs::write(
"/usr/local/share/ca-certificates/startos-root-ca.crt",
account.root_ca_cert.to_pem()?,
)
.await?;
Command::new("update-ca-certificates")
.invoke(crate::ErrorKind::OpenSsl)
.await?;
if let Some(wifi_interface) = &cfg.wifi_interface {
crate::net::wifi::synchronize_wpa_supplicant_conf(
&cfg.datadir().join("main"),
wifi_interface,
&server_info.last_wifi_region,
)
.await?;
tracing::info!("Synchronized WiFi");
}
let should_rebuild = tokio::fs::metadata(SYSTEM_REBUILD_PATH).await.is_ok()
|| &*server_info.version < &emver::Version::new(0, 3, 2, 0)
|| (*ARCH == "x86_64" && &*server_info.version < &emver::Version::new(0, 3, 4, 0));
let song = if should_rebuild {
Some(NonDetachingJoinHandle::from(tokio::spawn(async {
loop {
BEP.play().await.unwrap();
BEP.play().await.unwrap();
tokio::time::sleep(Duration::from_secs(60)).await;
}
})))
} else {
None
};
let log_dir = cfg.datadir().join("main/logs");
if tokio::fs::metadata(&log_dir).await.is_err() {
tokio::fs::create_dir_all(&log_dir).await?;
}
let current_machine_id = tokio::fs::read_to_string("/etc/machine-id").await?;
let mut machine_ids = tokio::fs::read_dir(&log_dir).await?;
while let Some(machine_id) = machine_ids.next_entry().await? {
if machine_id.file_name().to_string_lossy().trim() != current_machine_id.trim() {
tokio::fs::remove_dir_all(machine_id.path()).await?;
}
}
crate::disk::mount::util::bind(&log_dir, "/var/log/journal", false).await?;
match Command::new("chattr")
.arg("-R")
.arg("+C")
.arg("/var/log/journal")
.invoke(ErrorKind::Filesystem)
.await
{
Ok(_) => Ok(()),
Err(e) if e.source.to_string().contains("Operation not supported") => Ok(()),
Err(e) => Err(e),
}?;
Command::new("systemctl")
.arg("restart")
.arg("systemd-journald")
.invoke(crate::ErrorKind::Journald)
.await?;
tracing::info!("Mounted Logs");
let tmp_dir = cfg.datadir().join("package-data/tmp");
if should_rebuild && tokio::fs::metadata(&tmp_dir).await.is_ok() {
tokio::fs::remove_dir_all(&tmp_dir).await?;
}
if tokio::fs::metadata(&tmp_dir).await.is_err() {
tokio::fs::create_dir_all(&tmp_dir).await?;
}
let tmp_var = cfg.datadir().join(format!("package-data/tmp/var"));
if tokio::fs::metadata(&tmp_var).await.is_ok() {
tokio::fs::remove_dir_all(&tmp_var).await?;
}
crate::disk::mount::util::bind(&tmp_var, "/var/tmp", false).await?;
let tmp_docker = cfg
.datadir()
.join(format!("package-data/tmp/{CONTAINER_TOOL}"));
let tmp_docker_exists = tokio::fs::metadata(&tmp_docker).await.is_ok();
if CONTAINER_TOOL == "docker" {
Command::new("systemctl")
.arg("stop")
.arg("docker")
.invoke(crate::ErrorKind::Docker)
.await?;
}
crate::disk::mount::util::bind(&tmp_docker, CONTAINER_DATADIR, false).await?;
if CONTAINER_TOOL == "docker" {
Command::new("systemctl")
.arg("reset-failed")
.arg("docker")
.invoke(crate::ErrorKind::Docker)
.await?;
Command::new("systemctl")
.arg("start")
.arg("docker")
.invoke(crate::ErrorKind::Docker)
.await?;
}
tracing::info!("Mounted Docker Data");
if should_rebuild || !tmp_docker_exists {
if CONTAINER_TOOL == "docker" {
tracing::info!("Creating Docker Network");
create_bridge_network("start9", "172.18.0.1/24", "br-start9").await?;
tracing::info!("Created Docker Network");
}
tracing::info!("Loading System Docker Images");
crate::install::load_images("/usr/lib/startos/system-images").await?;
tracing::info!("Loaded System Docker Images");
tracing::info!("Loading Package Docker Images");
crate::install::load_images(cfg.datadir().join(PKG_ARCHIVE_DIR)).await?;
tracing::info!("Loaded Package Docker Images");
}
if CONTAINER_TOOL == "podman" {
crate::util::docker::remove_container("netdummy", true).await?;
Command::new("podman")
.arg("run")
.arg("-d")
.arg("--rm")
.arg("--network=start9")
.arg("--name=netdummy")
.arg("start9/x_system/utils:latest")
.arg("sleep")
.arg("infinity")
.invoke(crate::ErrorKind::Docker)
.await?;
}
tracing::info!("Enabling Docker QEMU Emulation");
Command::new(CONTAINER_TOOL)
.arg("run")
.arg("--privileged")
.arg("--rm")
.arg("start9/x_system/binfmt")
.arg("--install")
.arg("all")
.invoke(crate::ErrorKind::Docker)
.await?;
tracing::info!("Enabled Docker QEMU Emulation");
if current_governor()
.await?
.map(|g| &g != &GOVERNOR_PERFORMANCE)
.unwrap_or(false)
{
tracing::info!("Setting CPU Governor to \"{}\"", GOVERNOR_PERFORMANCE);
if get_available_governors()
.await?
.contains(&GOVERNOR_PERFORMANCE)
{
set_governor(&GOVERNOR_PERFORMANCE).await?;
tracing::info!("Set CPU Governor");
} else {
tracing::warn!("CPU Governor \"{}\" Not Available", GOVERNOR_PERFORMANCE)
}
}
let mut time_not_synced = true;
let mut not_made_progress = 0u32;
for _ in 0..1800 {
if check_time_is_synchronized().await? {
time_not_synced = false;
break;
}
let t = SystemTime::now();
tokio::time::sleep(Duration::from_secs(1)).await;
if t.elapsed()
.map(|t| t > Duration::from_secs_f64(1.1))
.unwrap_or(true)
{
not_made_progress = 0;
} else {
not_made_progress += 1;
}
if not_made_progress > 30 {
break;
}
}
if time_not_synced {
tracing::warn!("Timed out waiting for system time to synchronize");
} else {
tracing::info!("Syncronized system clock");
}
if server_info.zram {
crate::system::enable_zram().await?
}
server_info.ip_info = crate::net::dhcp::init_ips().await?;
server_info.status_info = ServerStatus {
updated: false,
update_progress: None,
backup_progress: None,
shutting_down: false,
restarting: false,
};
server_info.ntp_synced = if time_not_synced {
let db = db.clone();
tokio::spawn(async move {
while !check_time_is_synchronized().await.unwrap() {
tokio::time::sleep(Duration::from_secs(30)).await;
}
db.mutate(|v| v.as_server_info_mut().as_ntp_synced_mut().ser(&true))
.await
.unwrap()
});
false
} else {
true
};
db.mutate(|v| {
v.as_server_info_mut().ser(&server_info)?;
Ok(())
})
.await?;
crate::version::init(&db, &secret_store).await?;
db.mutate(|d| {
let model = d.de()?;
d.ser(&model)
})
.await?;
if should_rebuild {
match tokio::fs::remove_file(SYSTEM_REBUILD_PATH).await {
Ok(()) => Ok(()),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
Err(e) => Err(e),
}?;
}
drop(song);
tracing::info!("System initialized.");
Ok(InitResult { secret_store, db })
}

View File

@@ -0,0 +1,92 @@
use std::path::PathBuf;
use rpc_toolkit::command;
use crate::s9pk::manifest::Manifest;
use crate::s9pk::reader::S9pkReader;
use crate::util::display_none;
use crate::util::serde::{display_serializable, IoFormat};
use crate::Error;
#[command(subcommands(hash, manifest, license, icon, instructions, docker_images))]
pub fn inspect() -> Result<(), Error> {
Ok(())
}
#[command(cli_only)]
pub async fn hash(#[arg] path: PathBuf) -> Result<String, Error> {
Ok(S9pkReader::open(path, true)
.await?
.hash_str()
.unwrap()
.to_owned())
}
#[command(cli_only, display(display_serializable))]
pub async fn manifest(
#[arg] path: PathBuf,
#[arg(rename = "no-verify", long = "no-verify")] no_verify: bool,
#[allow(unused_variables)]
#[arg(long = "format")]
format: Option<IoFormat>,
) -> Result<Manifest, Error> {
S9pkReader::open(path, !no_verify).await?.manifest().await
}
#[command(cli_only, display(display_none))]
pub async fn license(
#[arg] path: PathBuf,
#[arg(rename = "no-verify", long = "no-verify")] no_verify: bool,
) -> Result<(), Error> {
tokio::io::copy(
&mut S9pkReader::open(path, !no_verify).await?.license().await?,
&mut tokio::io::stdout(),
)
.await?;
Ok(())
}
#[command(cli_only, display(display_none))]
pub async fn icon(
#[arg] path: PathBuf,
#[arg(rename = "no-verify", long = "no-verify")] no_verify: bool,
) -> Result<(), Error> {
tokio::io::copy(
&mut S9pkReader::open(path, !no_verify).await?.icon().await?,
&mut tokio::io::stdout(),
)
.await?;
Ok(())
}
#[command(cli_only, display(display_none))]
pub async fn instructions(
#[arg] path: PathBuf,
#[arg(rename = "no-verify", long = "no-verify")] no_verify: bool,
) -> Result<(), Error> {
tokio::io::copy(
&mut S9pkReader::open(path, !no_verify)
.await?
.instructions()
.await?,
&mut tokio::io::stdout(),
)
.await?;
Ok(())
}
#[command(cli_only, display(display_none), rename = "docker-images")]
pub async fn docker_images(
#[arg] path: PathBuf,
#[arg(rename = "no-verify", long = "no-verify")] no_verify: bool,
) -> Result<(), Error> {
tokio::io::copy(
&mut S9pkReader::open(path, !no_verify)
.await?
.docker_images()
.await?,
&mut tokio::io::stdout(),
)
.await?;
Ok(())
}

View File

@@ -0,0 +1,241 @@
use std::path::PathBuf;
use std::sync::Arc;
use models::OptionExt;
use sqlx::{Executor, Postgres};
use tracing::instrument;
use super::PKG_ARCHIVE_DIR;
use crate::context::RpcContext;
use crate::db::model::{
CurrentDependencies, Database, PackageDataEntry, PackageDataEntryInstalled,
PackageDataEntryMatchModelRef,
};
use crate::error::ErrorCollection;
use crate::prelude::*;
use crate::s9pk::manifest::PackageId;
use crate::util::{Apply, Version};
use crate::volume::{asset_dir, script_dir};
use crate::Error;
#[instrument(skip_all)]
pub async fn cleanup(ctx: &RpcContext, id: &PackageId, version: &Version) -> Result<(), Error> {
let mut errors = ErrorCollection::new();
ctx.managers.remove(&(id.clone(), version.clone())).await;
// docker images start9/$APP_ID/*:$VERSION -q | xargs docker rmi
let images = crate::util::docker::images_for(id, version).await?;
errors.extend(
futures::future::join_all(images.into_iter().map(|sha| async {
let sha = sha; // move into future
crate::util::docker::remove_image(&sha).await
}))
.await,
);
let pkg_archive_dir = ctx
.datadir
.join(PKG_ARCHIVE_DIR)
.join(id)
.join(version.as_str());
if tokio::fs::metadata(&pkg_archive_dir).await.is_ok() {
tokio::fs::remove_dir_all(&pkg_archive_dir)
.await
.apply(|res| errors.handle(res));
}
let assets_path = asset_dir(&ctx.datadir, id, version);
if tokio::fs::metadata(&assets_path).await.is_ok() {
tokio::fs::remove_dir_all(&assets_path)
.await
.apply(|res| errors.handle(res));
}
let scripts_path = script_dir(&ctx.datadir, id, version);
if tokio::fs::metadata(&scripts_path).await.is_ok() {
tokio::fs::remove_dir_all(&scripts_path)
.await
.apply(|res| errors.handle(res));
}
errors.into_result()
}
#[instrument(skip_all)]
pub async fn cleanup_failed(ctx: &RpcContext, id: &PackageId) -> Result<(), Error> {
if let Some(version) = match ctx
.db
.peek()
.await
.as_package_data()
.as_idx(id)
.or_not_found(id)?
.as_match()
{
PackageDataEntryMatchModelRef::Installing(m) => Some(m.as_manifest().as_version().de()?),
PackageDataEntryMatchModelRef::Restoring(m) => Some(m.as_manifest().as_version().de()?),
PackageDataEntryMatchModelRef::Updating(m) => {
let manifest_version = m.as_manifest().as_version().de()?;
let installed = m.as_installed().as_manifest().as_version().de()?;
if manifest_version != installed {
Some(manifest_version)
} else {
None // do not remove existing data
}
}
_ => {
tracing::warn!("{}: Nothing to clean up!", id);
None
}
} {
cleanup(ctx, id, &version).await?;
}
ctx.db
.mutate(|v| {
match v
.clone()
.into_package_data()
.into_idx(id)
.or_not_found(id)?
.as_match()
{
PackageDataEntryMatchModelRef::Installing(_)
| PackageDataEntryMatchModelRef::Restoring(_) => {
v.as_package_data_mut().remove(id)?;
}
PackageDataEntryMatchModelRef::Updating(pde) => {
v.as_package_data_mut()
.as_idx_mut(id)
.or_not_found(id)?
.ser(&PackageDataEntry::Installed(PackageDataEntryInstalled {
manifest: pde.as_installed().as_manifest().de()?,
static_files: pde.as_static_files().de()?,
installed: pde.as_installed().de()?,
}))?;
}
_ => (),
}
Ok(())
})
.await
}
#[instrument(skip_all)]
pub fn remove_from_current_dependents_lists(
db: &mut Model<Database>,
id: &PackageId,
current_dependencies: &CurrentDependencies,
) -> Result<(), Error> {
for dep in current_dependencies.0.keys().chain(std::iter::once(id)) {
if let Some(current_dependents) = db
.as_package_data_mut()
.as_idx_mut(dep)
.and_then(|d| d.as_installed_mut())
.map(|i| i.as_current_dependents_mut())
{
current_dependents.remove(id)?;
}
}
Ok(())
}
#[instrument(skip_all)]
pub async fn uninstall<Ex>(ctx: &RpcContext, secrets: &mut Ex, id: &PackageId) -> Result<(), Error>
where
for<'a> &'a mut Ex: Executor<'a, Database = Postgres>,
{
let db = ctx.db.peek().await;
let entry = db
.as_package_data()
.as_idx(id)
.or_not_found(id)?
.expect_as_removing()?;
let dependents_paths: Vec<PathBuf> = entry
.as_removing()
.as_current_dependents()
.keys()?
.into_iter()
.filter(|x| x != id)
.flat_map(|x| db.as_package_data().as_idx(&x))
.flat_map(|x| x.as_installed())
.flat_map(|x| x.as_manifest().as_volumes().de())
.flat_map(|x| x.values().cloned().collect::<Vec<_>>())
.flat_map(|x| x.pointer_path(&ctx.datadir))
.collect();
let volume_dir = ctx
.datadir
.join(crate::volume::PKG_VOLUME_DIR)
.join(&*entry.as_manifest().as_id().de()?);
let version = entry.as_removing().as_manifest().as_version().de()?;
tracing::debug!(
"Cleaning up {:?} except for {:?}",
volume_dir,
dependents_paths
);
cleanup(ctx, id, &version).await?;
cleanup_folder(volume_dir, Arc::new(dependents_paths)).await;
remove_network_keys(secrets, id).await?;
ctx.db
.mutate(|d| {
d.as_package_data_mut().remove(id)?;
remove_from_current_dependents_lists(
d,
id,
&entry.as_removing().as_current_dependencies().de()?,
)
})
.await
}
#[instrument(skip_all)]
pub async fn remove_network_keys<Ex>(secrets: &mut Ex, id: &PackageId) -> Result<(), Error>
where
for<'a> &'a mut Ex: Executor<'a, Database = Postgres>,
{
sqlx::query!("DELETE FROM network_keys WHERE package = $1", &*id)
.execute(&mut *secrets)
.await?;
sqlx::query!("DELETE FROM tor WHERE package = $1", &*id)
.execute(&mut *secrets)
.await?;
Ok(())
}
/// Removes a folder tree without removing the folders that are mounted into other docker containers
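/// A directory is removed wholesale when no dependent volume path lies at or under it;
/// otherwise we recurse into it and remove only what we safely can.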
pub fn cleanup_folder(
path: PathBuf,
dependents_volumes: Arc<Vec<PathBuf>>,
) -> futures::future::BoxFuture<'static, ()> {
Box::pin(async move {
let meta_data = match tokio::fs::metadata(&path).await {
Ok(a) => a,
Err(_e) => {
return;
}
};
if !meta_data.is_dir() {
tracing::error!("is_not dir, remove {:?}", path);
let _ = tokio::fs::remove_file(&path).await;
return;
}
if !dependents_volumes
.iter()
.any(|v| v.starts_with(&path) || v == &path)
{
tracing::error!("No parents, remove {:?}", path);
let _ = tokio::fs::remove_dir_all(&path).await;
return;
}
let mut read_dir = match tokio::fs::read_dir(&path).await {
Ok(a) => a,
Err(_e) => {
return;
}
};
tracing::error!("Parents, recurse {:?}", path);
while let Some(entry) = read_dir.next_entry().await.ok().flatten() {
let entry_path = entry.path();
cleanup_folder(entry_path, dependents_volumes.clone()).await;
}
})
}

File diff suppressed because it is too large

Binary file not shown (new image, 280 KiB)

View File

@@ -0,0 +1,228 @@
use std::future::Future;
use std::io::SeekFrom;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use models::{OptionExt, PackageId};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite};
use crate::db::model::Database;
use crate::prelude::*;
#[derive(Debug, Deserialize, Serialize, HasModel, Default)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct InstallProgress {
pub size: Option<u64>,
pub downloaded: AtomicU64,
pub download_complete: AtomicBool,
pub validated: AtomicU64,
pub validation_complete: AtomicBool,
pub unpacked: AtomicU64,
pub unpack_complete: AtomicBool,
}
impl InstallProgress {
pub fn new(size: Option<u64>) -> Self {
InstallProgress {
size,
downloaded: AtomicU64::new(0),
download_complete: AtomicBool::new(false),
validated: AtomicU64::new(0),
validation_complete: AtomicBool::new(false),
unpacked: AtomicU64::new(0),
unpack_complete: AtomicBool::new(false),
}
}
pub fn download_complete(&self) {
self.download_complete.store(true, Ordering::SeqCst)
}
pub async fn track_download(self: Arc<Self>, db: PatchDb, id: PackageId) -> Result<(), Error> {
let update = |d: &mut Model<Database>| {
d.as_package_data_mut()
.as_idx_mut(&id)
.or_not_found(&id)?
.as_install_progress_mut()
.or_not_found("install-progress")?
.ser(&self)
};
while !self.download_complete.load(Ordering::SeqCst) {
db.mutate(&update).await?;
tokio::time::sleep(Duration::from_millis(300)).await;
}
db.mutate(&update).await
}
pub async fn track_download_during<
F: FnOnce() -> Fut,
Fut: Future<Output = Result<T, Error>>,
T,
>(
self: &Arc<Self>,
db: PatchDb,
id: &PackageId,
f: F,
) -> Result<T, Error> {
let tracker = tokio::spawn(self.clone().track_download(db.clone(), id.clone()));
let res = f().await;
self.download_complete.store(true, Ordering::SeqCst);
tracker.await.unwrap()?;
res
}
pub async fn track_read(
self: Arc<Self>,
db: PatchDb,
id: PackageId,
complete: Arc<AtomicBool>,
) -> Result<(), Error> {
let update = |d: &mut Model<Database>| {
d.as_package_data_mut()
.as_idx_mut(&id)
.or_not_found(&id)?
.as_install_progress_mut()
.or_not_found("install-progress")?
.ser(&self)
};
while !complete.load(Ordering::SeqCst) {
db.mutate(&update).await?;
tokio::time::sleep(Duration::from_millis(300)).await;
}
db.mutate(&update).await
}
pub async fn track_read_during<
F: FnOnce() -> Fut,
Fut: Future<Output = Result<T, Error>>,
T,
>(
self: &Arc<Self>,
db: PatchDb,
id: &PackageId,
f: F,
) -> Result<T, Error> {
let complete = Arc::new(AtomicBool::new(false));
let tracker = tokio::spawn(self.clone().track_read(
db.clone(),
id.clone(),
complete.clone(),
));
let res = f().await;
complete.store(true, Ordering::SeqCst);
tracker.await.unwrap()?;
res
}
}
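// A minimal usage sketch (illustrative; `total`, `pkg_id`, and `download()` are
// placeholders): keep the install-progress record in the db refreshed (roughly every
// 300ms) while the download future runs.
//
//     let progress = Arc::new(InstallProgress::new(Some(total)));
//     progress
//         .track_download_during(db.clone(), &pkg_id, || async { download().await })
//         .await?;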
#[pin_project::pin_project]
#[derive(Debug)]
pub struct InstallProgressTracker<RW> {
#[pin]
inner: RW,
validating: bool,
progress: Arc<InstallProgress>,
}
impl<RW> InstallProgressTracker<RW> {
pub fn new(inner: RW, progress: Arc<InstallProgress>) -> Self {
InstallProgressTracker {
inner,
validating: true,
progress,
}
}
pub fn validated(&mut self) {
self.progress
.validation_complete
.store(true, Ordering::SeqCst);
self.validating = false;
}
}
impl<W: AsyncWrite> AsyncWrite for InstallProgressTracker<W> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let this = self.project();
match this.inner.poll_write(cx, buf) {
Poll::Ready(Ok(n)) => {
this.progress
.downloaded
.fetch_add(n as u64, Ordering::SeqCst);
Poll::Ready(Ok(n))
}
a => a,
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
let this = self.project();
this.inner.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.project();
this.inner.poll_shutdown(cx)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
let this = self.project();
match this.inner.poll_write_vectored(cx, bufs) {
Poll::Ready(Ok(n)) => {
this.progress
.downloaded
.fetch_add(n as u64, Ordering::SeqCst);
Poll::Ready(Ok(n))
}
a => a,
}
}
}
impl<R: AsyncRead> AsyncRead for InstallProgressTracker<R> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let this = self.project();
let prev = buf.filled().len() as u64;
match this.inner.poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
if *this.validating {
&this.progress.validated
} else {
&this.progress.unpacked
}
.fetch_add(buf.filled().len() as u64 - prev, Ordering::SeqCst);
Poll::Ready(Ok(()))
}
a => a,
}
}
}
impl<R: AsyncSeek> AsyncSeek for InstallProgressTracker<R> {
fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> std::io::Result<()> {
let this = self.project();
this.inner.start_seek(position)
}
fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<u64>> {
let this = self.project();
match this.inner.poll_complete(cx) {
Poll::Ready(Ok(n)) => {
if *this.validating {
&this.progress.validated
} else {
&this.progress.unpacked
}
.store(n, Ordering::SeqCst);
Poll::Ready(Ok(n))
}
a => a,
}
}
}

View File

@@ -0,0 +1,18 @@
use std::collections::BTreeMap;
use rpc_toolkit::command;
use tracing::instrument;
use crate::config::not_found;
use crate::context::RpcContext;
use crate::db::model::CurrentDependents;
use crate::prelude::*;
use crate::s9pk::manifest::PackageId;
use crate::util::serde::display_serializable;
use crate::util::Version;
use crate::Error;
#[command(subcommands(dry))]
pub async fn update() -> Result<(), Error> {
Ok(())
}

159
core/startos/src/lib.rs Normal file
View File

@@ -0,0 +1,159 @@
#![recursion_limit = "256"]
pub const DEFAULT_MARKETPLACE: &str = "https://registry.start9.com";
// pub const COMMUNITY_MARKETPLACE: &str = "https://community-registry.start9.com";
pub const BUFFER_SIZE: usize = 1024;
pub const HOST_IP: [u8; 4] = [172, 18, 0, 1];
pub const TARGET: &str = current_platform::CURRENT_PLATFORM;
lazy_static::lazy_static! {
pub static ref ARCH: &'static str = {
let (arch, _) = TARGET.split_once("-").unwrap();
arch
};
pub static ref PLATFORM: String = {
if let Ok(platform) = std::fs::read_to_string("/usr/lib/startos/PLATFORM.txt") {
platform
} else {
ARCH.to_string()
}
};
pub static ref SOURCE_DATE: SystemTime = {
std::fs::metadata(std::env::current_exe().unwrap()).unwrap().modified().unwrap()
};
}
pub mod account;
pub mod action;
pub mod auth;
pub mod backup;
pub mod bins;
pub mod config;
pub mod context;
pub mod control;
pub mod core;
pub mod db;
pub mod dependencies;
pub mod developer;
pub mod diagnostic;
pub mod disk;
pub mod error;
pub mod firmware;
pub mod hostname;
pub mod init;
pub mod inspect;
pub mod install;
pub mod logs;
pub mod manager;
pub mod middleware;
pub mod migration;
pub mod net;
pub mod notifications;
pub mod os_install;
pub mod prelude;
pub mod procedure;
pub mod properties;
pub mod registry;
pub mod s9pk;
pub mod setup;
pub mod shutdown;
pub mod sound;
pub mod ssh;
pub mod status;
pub mod system;
pub mod update;
pub mod util;
pub mod version;
pub mod volume;
use std::time::SystemTime;
pub use config::Config;
pub use error::{Error, ErrorKind, ResultExt};
use rpc_toolkit::command;
use rpc_toolkit::yajrc::RpcError;
#[command(metadata(authenticated = false))]
pub fn echo(#[arg] message: String) -> Result<String, RpcError> {
Ok(message)
}
#[command(subcommands(
version::git_info,
echo,
inspect::inspect,
server,
package,
net::net,
auth::auth,
db::db,
ssh::ssh,
net::wifi::wifi,
disk::disk,
notifications::notification,
backup::backup,
registry::marketplace::marketplace,
))]
pub fn main_api() -> Result<(), RpcError> {
Ok(())
}
#[command(subcommands(
system::time,
system::experimental,
system::logs,
system::kernel_logs,
system::metrics,
shutdown::shutdown,
shutdown::restart,
shutdown::rebuild,
update::update_system,
))]
pub fn server() -> Result<(), RpcError> {
Ok(())
}
#[command(subcommands(
action::action,
install::install,
install::sideload,
install::uninstall,
install::list,
config::config,
control::start,
control::stop,
control::restart,
logs::logs,
properties::properties,
dependencies::dependency,
backup::package_backup,
))]
pub fn package() -> Result<(), RpcError> {
Ok(())
}
#[command(subcommands(
version::git_info,
s9pk::pack,
developer::verify,
developer::init,
inspect::inspect,
registry::admin::publish,
))]
pub fn portable_api() -> Result<(), RpcError> {
Ok(())
}
#[command(subcommands(version::git_info, echo, diagnostic::diagnostic))]
pub fn diagnostic_api() -> Result<(), RpcError> {
Ok(())
}
#[command(subcommands(version::git_info, echo, setup::setup))]
pub fn setup_api() -> Result<(), RpcError> {
Ok(())
}
#[command(subcommands(version::git_info, echo, os_install::install))]
pub fn install_api() -> Result<(), RpcError> {
Ok(())
}

551
core/startos/src/logs.rs Normal file
View File

@@ -0,0 +1,551 @@
use std::future::Future;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::process::Stdio;
use std::time::{Duration, UNIX_EPOCH};
use chrono::{DateTime, Utc};
use color_eyre::eyre::eyre;
use futures::stream::BoxStream;
use futures::{FutureExt, SinkExt, Stream, StreamExt, TryStreamExt};
use hyper::upgrade::Upgraded;
use hyper::Error as HyperError;
use rpc_toolkit::command;
use rpc_toolkit::yajrc::RpcError;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::{Child, Command};
use tokio::task::JoinError;
use tokio_stream::wrappers::LinesStream;
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
use tokio_tungstenite::tungstenite::protocol::CloseFrame;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::WebSocketStream;
use tracing::instrument;
use crate::context::{CliContext, RpcContext};
use crate::core::rpc_continuations::{RequestGuid, RpcContinuation};
use crate::error::ResultExt;
use crate::procedure::docker::DockerProcedure;
use crate::s9pk::manifest::PackageId;
use crate::util::display_none;
use crate::util::serde::Reversible;
use crate::{Error, ErrorKind};
#[pin_project::pin_project]
pub struct LogStream {
_child: Child,
#[pin]
entries: BoxStream<'static, Result<JournalctlEntry, Error>>,
}
impl Deref for LogStream {
type Target = BoxStream<'static, Result<JournalctlEntry, Error>>;
fn deref(&self) -> &Self::Target {
&self.entries
}
}
impl DerefMut for LogStream {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.entries
}
}
impl Stream for LogStream {
type Item = Result<JournalctlEntry, Error>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let this = self.project();
Stream::poll_next(this.entries, cx)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.entries.size_hint()
}
}
#[instrument(skip_all)]
async fn ws_handler<
WSFut: Future<Output = Result<Result<WebSocketStream<Upgraded>, HyperError>, JoinError>>,
>(
first_entry: Option<LogEntry>,
mut logs: LogStream,
ws_fut: WSFut,
) -> Result<(), Error> {
let mut stream = ws_fut
.await
.with_kind(crate::ErrorKind::Network)?
.with_kind(crate::ErrorKind::Unknown)?;
if let Some(first_entry) = first_entry {
stream
.send(Message::Text(
serde_json::to_string(&first_entry).with_kind(ErrorKind::Serialization)?,
))
.await
.with_kind(ErrorKind::Network)?;
}
let mut ws_closed = false;
while let Some(entry) = tokio::select! {
a = logs.try_next() => Some(a?),
a = stream.try_next() => { a.with_kind(crate::ErrorKind::Network)?; ws_closed = true; None }
} {
if let Some(entry) = entry {
let (_, log_entry) = entry.log_entry()?;
stream
.send(Message::Text(
serde_json::to_string(&log_entry).with_kind(ErrorKind::Serialization)?,
))
.await
.with_kind(ErrorKind::Network)?;
}
}
if !ws_closed {
stream
.close(Some(CloseFrame {
code: CloseCode::Normal,
reason: "Log Stream Finished".into(),
}))
.await
.with_kind(ErrorKind::Network)?;
}
Ok(())
}
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct LogResponse {
entries: Reversible<LogEntry>,
start_cursor: Option<String>,
end_cursor: Option<String>,
}
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct LogFollowResponse {
start_cursor: Option<String>,
guid: RequestGuid,
}
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
pub struct LogEntry {
timestamp: DateTime<Utc>,
message: String,
}
impl std::fmt::Display for LogEntry {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"{} {}",
self.timestamp
.to_rfc3339_opts(chrono::SecondsFormat::Millis, true),
self.message
)
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct JournalctlEntry {
#[serde(rename = "__REALTIME_TIMESTAMP")]
pub timestamp: String,
#[serde(rename = "MESSAGE")]
#[serde(deserialize_with = "deserialize_log_message")]
pub message: String,
#[serde(rename = "__CURSOR")]
pub cursor: String,
}
impl JournalctlEntry {
fn log_entry(self) -> Result<(String, LogEntry), Error> {
Ok((
self.cursor,
LogEntry {
timestamp: DateTime::<Utc>::from(
UNIX_EPOCH + Duration::from_micros(self.timestamp.parse::<u64>()?),
),
message: self.message,
},
))
}
}
fn deserialize_log_message<'de, D: serde::de::Deserializer<'de>>(
deserializer: D,
) -> std::result::Result<String, D::Error> {
struct Visitor;
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = String;
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "a parsable string")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(v.trim().to_owned())
}
fn visit_unit<E>(self) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(String::new())
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
String::from_utf8(
std::iter::repeat_with(|| seq.next_element::<u8>().transpose())
.take_while(|a| a.is_some())
.flatten()
.collect::<Result<Vec<u8>, _>>()?,
)
.map(|s| s.trim().to_owned())
.map_err(serde::de::Error::custom)
}
}
deserializer.deserialize_any(Visitor)
}
/// Defines how we filter a journalctl log query.
/// Kernel: `-k`/`--dmesg` shows the kernel message log from the current boot.
/// Unit: `-u`/`--unit=UNIT` shows logs from the specified unit
///       (`--user-unit=UNIT` for a user unit).
/// System: the unit is startd, but we also filter on the comm.
/// Container: podman/docker containers are filtered by CONTAINER_NAME.
#[derive(Debug)]
pub enum LogSource {
Kernel,
Unit(&'static str),
System,
Container(PackageId),
}
pub const SYSTEM_UNIT: &str = "startd";
#[command(
custom_cli(cli_logs(async, context(CliContext))),
subcommands(self(logs_nofollow(async)), logs_follow),
display(display_none)
)]
pub async fn logs(
#[arg] id: PackageId,
#[arg(short = 'l', long = "limit")] limit: Option<usize>,
#[arg(short = 'c', long = "cursor")] cursor: Option<String>,
#[arg(short = 'B', long = "before", default)] before: bool,
#[arg(short = 'f', long = "follow", default)] follow: bool,
) -> Result<(PackageId, Option<usize>, Option<String>, bool, bool), Error> {
Ok((id, limit, cursor, before, follow))
}
pub async fn cli_logs(
ctx: CliContext,
(id, limit, cursor, before, follow): (PackageId, Option<usize>, Option<String>, bool, bool),
) -> Result<(), RpcError> {
if follow {
if cursor.is_some() {
return Err(RpcError::from(Error::new(
eyre!("The argument '--cursor <cursor>' cannot be used with '--follow'"),
crate::ErrorKind::InvalidRequest,
)));
}
if before {
return Err(RpcError::from(Error::new(
eyre!("The argument '--before' cannot be used with '--follow'"),
crate::ErrorKind::InvalidRequest,
)));
}
cli_logs_generic_follow(ctx, "package.logs.follow", Some(id), limit).await
} else {
cli_logs_generic_nofollow(ctx, "package.logs", Some(id), limit, cursor, before).await
}
}
pub async fn logs_nofollow(
_ctx: (),
(id, limit, cursor, before, _): (PackageId, Option<usize>, Option<String>, bool, bool),
) -> Result<LogResponse, Error> {
fetch_logs(LogSource::Container(id), limit, cursor, before).await
}
#[command(rpc_only, rename = "follow", display(display_none))]
pub async fn logs_follow(
#[context] ctx: RpcContext,
#[parent_data] (id, limit, _, _, _): (PackageId, Option<usize>, Option<String>, bool, bool),
) -> Result<LogFollowResponse, Error> {
follow_logs(ctx, LogSource::Container(id), limit).await
}
pub async fn cli_logs_generic_nofollow(
ctx: CliContext,
method: &str,
id: Option<PackageId>,
limit: Option<usize>,
cursor: Option<String>,
before: bool,
) -> Result<(), RpcError> {
let res = rpc_toolkit::command_helpers::call_remote(
ctx.clone(),
method,
serde_json::json!({
"id": id,
"limit": limit,
"cursor": cursor,
"before": before,
}),
PhantomData::<LogResponse>,
)
.await?
.result?;
for entry in res.entries.iter() {
println!("{}", entry);
}
Ok(())
}
pub async fn cli_logs_generic_follow(
ctx: CliContext,
method: &str,
id: Option<PackageId>,
limit: Option<usize>,
) -> Result<(), RpcError> {
let res = rpc_toolkit::command_helpers::call_remote(
ctx.clone(),
method,
serde_json::json!({
"id": id,
"limit": limit,
}),
PhantomData::<LogFollowResponse>,
)
.await?
.result?;
let mut base_url = ctx.base_url.clone();
let ws_scheme = match base_url.scheme() {
"https" => "wss",
"http" => "ws",
_ => {
return Err(Error::new(
eyre!("Cannot parse scheme from base URL"),
crate::ErrorKind::ParseUrl,
)
.into())
}
};
base_url
.set_scheme(ws_scheme)
.map_err(|_| Error::new(eyre!("Cannot set URL scheme"), crate::ErrorKind::ParseUrl))?;
let (mut stream, _) =
// base_url is "http://127.0.0.1/", with a trailing slash, so we don't put a leading slash in this path:
tokio_tungstenite::connect_async(format!("{}ws/rpc/{}", base_url, res.guid)).await?;
while let Some(log) = stream.try_next().await? {
if let Message::Text(log) = log {
println!("{}", serde_json::from_str::<LogEntry>(&log)?);
}
}
Ok(())
}
pub async fn journalctl(
id: LogSource,
limit: usize,
cursor: Option<&str>,
before: bool,
follow: bool,
) -> Result<LogStream, Error> {
let mut cmd = Command::new("journalctl");
cmd.kill_on_drop(true);
cmd.arg("--output=json");
cmd.arg("--output-fields=MESSAGE");
cmd.arg(format!("-n{}", limit));
match id {
LogSource::Kernel => {
cmd.arg("-k");
}
LogSource::Unit(id) => {
cmd.arg("-u");
cmd.arg(id);
}
LogSource::System => {
cmd.arg("-u");
cmd.arg(SYSTEM_UNIT);
cmd.arg(format!("_COMM={}", SYSTEM_UNIT));
}
LogSource::Container(id) => {
#[cfg(not(feature = "docker"))]
cmd.arg(format!(
"SYSLOG_IDENTIFIER={}",
DockerProcedure::container_name(&id, None)
));
#[cfg(feature = "docker")]
cmd.arg(format!(
"CONTAINER_NAME={}",
DockerProcedure::container_name(&id, None)
));
}
};
let cursor_formatted = format!("--after-cursor={}", cursor.unwrap_or(""));
if cursor.is_some() {
cmd.arg(&cursor_formatted);
if before {
cmd.arg("--reverse");
}
}
if follow {
cmd.arg("--follow");
}
let mut child = cmd.stdout(Stdio::piped()).spawn()?;
let out = BufReader::new(
child
.stdout
.take()
.ok_or_else(|| Error::new(eyre!("No stdout available"), crate::ErrorKind::Journald))?,
);
let journalctl_entries = LinesStream::new(out.lines());
let deserialized_entries = journalctl_entries
.map_err(|e| Error::new(e, crate::ErrorKind::Journald))
.and_then(|s| {
futures::future::ready(
serde_json::from_str::<JournalctlEntry>(&s)
.with_kind(crate::ErrorKind::Deserialization),
)
});
Ok(LogStream {
_child: child,
entries: deserialized_entries.boxed(),
})
}
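// For reference, with `LogSource::Container("hello-world")`, a limit of 50, and
// `follow = true`, the spawned command is roughly (assuming the "docker" feature and
// the "<package-id>.embassy" container naming used in the commented-out test below):
//
//     journalctl --output=json --output-fields=MESSAGE -n50 \
//         CONTAINER_NAME=hello-world.embassy --follow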
#[instrument(skip_all)]
pub async fn fetch_logs(
id: LogSource,
limit: Option<usize>,
cursor: Option<String>,
before: bool,
) -> Result<LogResponse, Error> {
let limit = limit.unwrap_or(50);
let mut stream = journalctl(id, limit, cursor.as_deref(), before, false).await?;
let mut entries = Vec::with_capacity(limit);
let mut start_cursor = None;
if let Some(first) = tokio::time::timeout(Duration::from_secs(1), stream.try_next())
.await
.ok()
.transpose()?
.flatten()
{
let (cursor, entry) = first.log_entry()?;
start_cursor = Some(cursor);
entries.push(entry);
}
let (mut end_cursor, entries) = stream
.try_fold(
(start_cursor.clone(), entries),
|(_, mut acc), entry| async move {
let (cursor, entry) = entry.log_entry()?;
acc.push(entry);
Ok((Some(cursor), acc))
},
)
.await?;
let mut entries = Reversible::new(entries);
// reverse again so output is always in increasing chronological order
if cursor.is_some() && before {
entries.reverse();
std::mem::swap(&mut start_cursor, &mut end_cursor);
}
Ok(LogResponse {
entries,
start_cursor,
end_cursor,
})
}
#[instrument(skip_all)]
pub async fn follow_logs(
ctx: RpcContext,
id: LogSource,
limit: Option<usize>,
) -> Result<LogFollowResponse, Error> {
let limit = limit.unwrap_or(50);
let mut stream = journalctl(id, limit, None, false, true).await?;
let mut start_cursor = None;
let mut first_entry = None;
if let Some(first) = tokio::time::timeout(Duration::from_secs(1), stream.try_next())
.await
.ok()
.transpose()?
.flatten()
{
let (cursor, entry) = first.log_entry()?;
start_cursor = Some(cursor);
first_entry = Some(entry);
}
let guid = RequestGuid::new();
ctx.add_continuation(
guid.clone(),
RpcContinuation::ws(
Box::new(move |ws_fut| ws_handler(first_entry, stream, ws_fut).boxed()),
Duration::from_secs(30),
),
)
.await;
Ok(LogFollowResponse { start_cursor, guid })
}
// #[tokio::test]
// pub async fn test_logs() {
// let response = fetch_logs(
// // change `tor.service` to an actual journald unit on your machine
// // LogSource::Service("tor.service"),
// // first run `docker run --name=hello-world.embassy --log-driver=journald hello-world`
// LogSource::Container("hello-world".parse().unwrap()),
// // Some(5),
// None,
// None,
// // Some("s=1b8c418e28534400856c27b211dd94fd;i=5a7;b=97571c13a1284f87bc0639b5cff5acbe;m=740e916;t=5ca073eea3445;x=f45bc233ca328348".to_owned()),
// false,
// true,
// )
// .await
// .unwrap();
// let serialized = serde_json::to_string_pretty(&response).unwrap();
// println!("{}", serialized);
// }
// #[tokio::test]
// pub async fn test_logs() {
// let mut cmd = Command::new("journalctl");
// cmd.kill_on_drop(true);
// cmd.arg("-f");
// cmd.arg("CONTAINER_NAME=hello-world.embassy");
// let mut child = cmd.stdout(Stdio::piped()).spawn().unwrap();
// let out = BufReader::new(
// child
// .stdout
// .take()
// .ok_or_else(|| Error::new(eyre!("No stdout available"), crate::ErrorKind::Journald))
// .unwrap(),
// );
// let mut journalctl_entries = LinesStream::new(out.lines());
// while let Some(line) = journalctl_entries.try_next().await.unwrap() {
// dbg!(line);
// }
// }

3
core/startos/src/main.rs Normal file
View File

@@ -0,0 +1,3 @@
fn main() {
startos::bins::startbox()
}

View File

@@ -0,0 +1,56 @@
use models::OptionExt;
use tracing::instrument;
use crate::context::RpcContext;
use crate::prelude::*;
use crate::s9pk::manifest::PackageId;
use crate::status::MainStatus;
use crate::Error;
/// Runs one health check cycle for a service: executes its health checks and stores the results in the db
#[instrument(skip_all)]
pub async fn check(ctx: &RpcContext, id: &PackageId) -> Result<(), Error> {
let (manifest, started) = {
let peeked = ctx.db.peek().await;
let pde = peeked
.as_package_data()
.as_idx(id)
.or_not_found(id)?
.expect_as_installed()?;
let manifest = pde.as_installed().as_manifest().de()?;
let started = pde.as_installed().as_status().as_main().de()?.started();
(manifest, started)
};
let health_results = if let Some(started) = started {
tracing::debug!("Checking health of {}", id);
manifest
.health_checks
.check_all(ctx, started, id, &manifest.version, &manifest.volumes)
.await?
} else {
return Ok(());
};
ctx.db
.mutate(|v| {
let pde = v
.as_package_data_mut()
.as_idx_mut(id)
.or_not_found(id)?
.expect_as_installed_mut()?;
let status = pde.as_installed_mut().as_status_mut().as_main_mut();
if let MainStatus::Running { health: _, started } = status.de()? {
status.ser(&MainStatus::Running {
health: health_results.clone(),
started,
})?;
}
Ok(())
})
.await
}

View File

@@ -0,0 +1,300 @@
use std::sync::Arc;
use std::time::Duration;
use models::OptionExt;
use tokio::sync::watch;
use tokio::sync::watch::Sender;
use tracing::instrument;
use super::start_stop::StartStop;
use super::{manager_seed, run_main, ManagerPersistentContainer, RunMainResult};
use crate::prelude::*;
use crate::procedure::NoOutput;
use crate::s9pk::manifest::Manifest;
use crate::status::MainStatus;
use crate::util::NonDetachingJoinHandle;
use crate::Error;
pub type ManageContainerOverride = Arc<watch::Sender<Option<Override>>>;
pub type Override = MainStatus;
pub struct OverrideGuard {
override_main_status: Option<ManageContainerOverride>,
}
impl OverrideGuard {
pub fn drop(self) {}
}
impl Drop for OverrideGuard {
fn drop(&mut self) {
if let Some(override_main_status) = self.override_main_status.take() {
override_main_status.send_modify(|x| {
*x = None;
});
}
}
}
/// State machine actor for a service: tracks the service's current and desired
/// running states.
pub struct ManageContainer {
pub(super) current_state: Arc<watch::Sender<StartStop>>,
pub(super) desired_state: Arc<watch::Sender<StartStop>>,
_service: NonDetachingJoinHandle<()>,
_save_state: NonDetachingJoinHandle<()>,
override_main_status: ManageContainerOverride,
}
impl ManageContainer {
pub async fn new(
seed: Arc<manager_seed::ManagerSeed>,
persistent_container: ManagerPersistentContainer,
) -> Result<Self, Error> {
let current_state = Arc::new(watch::channel(StartStop::Stop).0);
let desired_state = Arc::new(
watch::channel::<StartStop>(
get_status(seed.ctx.db.peek().await, &seed.manifest).into(),
)
.0,
);
let override_main_status: ManageContainerOverride = Arc::new(watch::channel(None).0);
let service = tokio::spawn(create_service_manager(
desired_state.clone(),
seed.clone(),
current_state.clone(),
persistent_container,
))
.into();
let save_state = tokio::spawn(save_state(
desired_state.clone(),
current_state.clone(),
override_main_status.clone(),
seed.clone(),
))
.into();
Ok(ManageContainer {
current_state,
desired_state,
_service: service,
override_main_status,
_save_state: save_state,
})
}
/// Sets an override, used during operations like a service restart when the displayed
/// status should differ from the actual status of the service.
pub fn set_override(&self, override_status: Override) -> Result<OverrideGuard, Error> {
let status = Some(override_status);
if self.override_main_status.borrow().is_some() {
return Err(Error::new(
eyre!("Already have an override"),
ErrorKind::InvalidRequest,
));
}
self.override_main_status
.send_modify(|x| *x = status.clone());
Ok(OverrideGuard {
override_main_status: Some(self.override_main_status.clone()),
})
}
/// Sets the override without a guard to revert it. Used only by the manager during a shutdown.
pub(super) async fn lock_state_forever(
&self,
seed: &manager_seed::ManagerSeed,
) -> Result<(), Error> {
let current_state = get_status(seed.ctx.db.peek().await, &seed.manifest);
self.override_main_status
.send_modify(|x| *x = Some(current_state));
Ok(())
}
/// Sets the desired state of the service (start or stop).
pub fn to_desired(&self, new_state: StartStop) {
self.desired_state.send_modify(|x| *x = new_state);
}
/// Sets the desired state, then waits until the service's current state matches it.
pub async fn wait_for_desired(&self, new_state: StartStop) {
let mut current_state = self.current_state();
self.to_desired(new_state);
while *current_state.borrow() != new_state {
current_state.changed().await.unwrap_or_default();
}
}
/// Returns a watch receiver for the current state.
pub fn current_state(&self) -> watch::Receiver<StartStop> {
self.current_state.subscribe()
}
/// Returns a watch receiver for the desired state.
pub fn desired_state(&self) -> watch::Receiver<StartStop> {
self.desired_state.subscribe()
}
}
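// A minimal usage sketch (illustrative; `MainStatus::Starting` is just an example
// status to display): hold an override while the service is cycled, and let the guard
// revert it when dropped.
//
//     let guard = manage_container.set_override(MainStatus::Starting)?;
//     manage_container.wait_for_desired(StartStop::Stop).await;
//     manage_container.wait_for_desired(StartStop::Start).await;
//     drop(guard);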
async fn create_service_manager(
desired_state: Arc<Sender<StartStop>>,
seed: Arc<manager_seed::ManagerSeed>,
current_state: Arc<Sender<StartStop>>,
persistent_container: Arc<Option<super::persistent_container::PersistentContainer>>,
) {
let mut desired_state_receiver = desired_state.subscribe();
let mut running_service: Option<NonDetachingJoinHandle<()>> = None;
let seed = seed.clone();
loop {
let current: StartStop = *current_state.borrow();
let desired: StartStop = *desired_state_receiver.borrow();
match (current, desired) {
(StartStop::Start, StartStop::Start) => (),
(StartStop::Start, StartStop::Stop) => {
if persistent_container.is_none() {
if let Err(err) = seed.stop_container().await {
tracing::error!("Could not stop container");
tracing::debug!("{:?}", err)
}
running_service = None;
} else if let Some(current_service) = running_service.take() {
tokio::select! {
_ = current_service => (),
_ = tokio::time::sleep(Duration::from_secs_f64(seed.manifest
.containers
.as_ref()
.and_then(|c| c.main.sigterm_timeout).map(|x| x.as_secs_f64()).unwrap_or_default())) => {
tracing::error!("Could not stop service");
}
}
}
current_state.send_modify(|x| *x = StartStop::Stop);
}
(StartStop::Stop, StartStop::Start) => starting_service(
current_state.clone(),
desired_state.clone(),
seed.clone(),
persistent_container.clone(),
&mut running_service,
),
(StartStop::Stop, StartStop::Stop) => (),
}
if desired_state_receiver.changed().await.is_err() {
tracing::error!("Desired state error");
break;
}
}
}
async fn save_state(
desired_state: Arc<Sender<StartStop>>,
current_state: Arc<Sender<StartStop>>,
override_main_status: ManageContainerOverride,
seed: Arc<manager_seed::ManagerSeed>,
) {
let mut desired_state_receiver = desired_state.subscribe();
let mut current_state_receiver = current_state.subscribe();
let mut override_main_status_receiver = override_main_status.subscribe();
loop {
let current: StartStop = *current_state_receiver.borrow();
let desired: StartStop = *desired_state_receiver.borrow();
let override_status = override_main_status_receiver.borrow().clone();
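// Derive the persisted MainStatus from (override, current, desired): an explicit override always
// wins; otherwise Start/Start => Running, Start/Stop => Stopping, Stop/Start => Starting,
// Stop/Stop => Stopped.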
let status = match (override_status.clone(), current, desired) {
(Some(status), _, _) => status,
(_, StartStop::Start, StartStop::Start) => MainStatus::Running {
started: chrono::Utc::now(),
health: Default::default(),
},
(_, StartStop::Start, StartStop::Stop) => MainStatus::Stopping,
(_, StartStop::Stop, StartStop::Start) => MainStatus::Starting,
(_, StartStop::Stop, StartStop::Stop) => MainStatus::Stopped,
};
let manifest = &seed.manifest;
if let Err(err) = seed
.ctx
.db
.mutate(|db| set_status(db, manifest, &status))
.await
{
tracing::error!("Did not set status for {}", seed.container_name);
tracing::debug!("{:?}", err);
}
tokio::select! {
_ = desired_state_receiver.changed() =>{},
_ = current_state_receiver.changed() => {},
_ = override_main_status_receiver.changed() => {}
}
}
}
fn starting_service(
current_state: Arc<Sender<StartStop>>,
desired_state: Arc<Sender<StartStop>>,
seed: Arc<manager_seed::ManagerSeed>,
persistent_container: ManagerPersistentContainer,
running_service: &mut Option<NonDetachingJoinHandle<()>>,
) {
let set_running = {
let current_state = current_state.clone();
Arc::new(move || {
current_state.send_modify(|x| *x = StartStop::Start);
})
};
let set_stopped = { move || current_state.send_modify(|x| *x = StartStop::Stop) };
let running_main_loop = async move {
while desired_state.borrow().is_start() {
let result = run_main(
seed.clone(),
persistent_container.clone(),
set_running.clone(),
)
.await;
set_stopped();
run_main_log_result(result, seed.clone()).await;
}
};
*running_service = Some(tokio::spawn(running_main_loop).into());
}
async fn run_main_log_result(result: RunMainResult, seed: Arc<manager_seed::ManagerSeed>) {
match result {
Ok(Ok(NoOutput)) => (), // restart
Ok(Err(e)) => {
tracing::error!(
"The service {} has crashed with the following exit code: {}",
seed.manifest.id.clone(),
e.0
);
tokio::time::sleep(Duration::from_secs(15)).await;
}
Err(e) => {
tracing::error!("failed to start service: {}", e);
tracing::debug!("{:?}", e);
}
}
}
/// Used only within this module, e.g. while performing a backup
#[instrument(skip(db, manifest))]
pub(super) fn get_status(db: Peeked, manifest: &Manifest) -> MainStatus {
db.as_package_data()
.as_idx(&manifest.id)
.and_then(|x| x.as_installed())
.filter(|x| x.as_manifest().as_version().de().ok() == Some(manifest.version.clone()))
.and_then(|x| x.as_status().as_main().de().ok())
.unwrap_or(MainStatus::Stopped)
}
#[instrument(skip(db, manifest))]
fn set_status(db: &mut Peeked, manifest: &Manifest, main_status: &MainStatus) -> Result<(), Error> {
let Some(installed) = db
.as_package_data_mut()
.as_idx_mut(&manifest.id)
.or_not_found(&manifest.id)?
.as_installed_mut()
else {
return Ok(());
};
installed.as_status_mut().as_main_mut().ser(main_status)
}

View File

@@ -0,0 +1,96 @@
use std::collections::BTreeMap;
use std::sync::Arc;
use color_eyre::eyre::eyre;
use tokio::sync::RwLock;
use tracing::instrument;
use super::Manager;
use crate::context::RpcContext;
use crate::prelude::*;
use crate::s9pk::manifest::{Manifest, PackageId};
use crate::util::Version;
use crate::Error;
/// The structure that contains all the service managers
#[derive(Default)]
pub struct ManagerMap(RwLock<BTreeMap<(PackageId, Version), Arc<Manager>>>);
impl ManagerMap {
#[instrument(skip_all)]
pub async fn init(&self, ctx: RpcContext, peeked: Peeked) -> Result<(), Error> {
let mut res = BTreeMap::new();
for package in peeked.as_package_data().keys()? {
let man: Manifest = if let Some(manifest) = peeked
.as_package_data()
.as_idx(&package)
.and_then(|x| x.as_installed())
.map(|x| x.as_manifest().de())
{
manifest?
} else {
continue;
};
res.insert(
(package, man.version.clone()),
Arc::new(Manager::new(ctx.clone(), man).await?),
);
}
*self.0.write().await = res;
Ok(())
}
/// Used during the install process
#[instrument(skip_all)]
pub async fn add(&self, ctx: RpcContext, manifest: Manifest) -> Result<Arc<Manager>, Error> {
let mut lock = self.0.write().await;
let id = (manifest.id.clone(), manifest.version.clone());
if let Some(man) = lock.remove(&id) {
man.exit().await;
}
let manager = Arc::new(Manager::new(ctx.clone(), manifest).await?);
lock.insert(id, manager.clone());
Ok(manager)
}
/// Run during cleanup, i.e. when we are uninstalling the service
#[instrument(skip_all)]
pub async fn remove(&self, id: &(PackageId, Version)) {
if let Some(man) = self.0.write().await.remove(id) {
man.exit().await;
}
}
/// Used during a shutdown
#[instrument(skip_all)]
pub async fn empty(&self) -> Result<(), Error> {
let res =
futures::future::join_all(std::mem::take(&mut *self.0.write().await).into_iter().map(
|((id, version), man)| async move {
tracing::debug!("Manager for {}@{} shutting down", id, version);
man.shutdown().await?;
tracing::debug!("Manager for {}@{} is shutdown", id, version);
if let Err(e) = Arc::try_unwrap(man) {
tracing::trace!(
"Manager for {}@{} still has {} other open references",
id,
version,
Arc::strong_count(&e) - 1
);
}
Ok::<_, Error>(())
},
))
.await;
res.into_iter().fold(Ok(()), |res, x| match (res, x) {
(Ok(()), x) => x,
(Err(e), Ok(())) => Err(e),
(Err(e1), Err(e2)) => Err(Error::new(eyre!("{}, {}", e1.source, e2.source), e1.kind)),
})
}
#[instrument(skip_all)]
pub async fn get(&self, id: &(PackageId, Version)) -> Option<Arc<Manager>> {
self.0.read().await.get(id).cloned()
}
}

View File

@@ -0,0 +1,37 @@
use models::ErrorKind;
use crate::context::RpcContext;
use crate::procedure::docker::DockerProcedure;
use crate::procedure::PackageProcedure;
use crate::s9pk::manifest::Manifest;
use crate::util::docker::stop_container;
use crate::Error;
/// Helper structure for a service: the seed of data needed by the manager_container
pub struct ManagerSeed {
pub ctx: RpcContext,
pub manifest: Manifest,
pub container_name: String,
}
impl ManagerSeed {
pub async fn stop_container(&self) -> Result<(), Error> {
match stop_container(
&self.container_name,
match &self.manifest.main {
PackageProcedure::Docker(DockerProcedure {
sigterm_timeout: Some(sigterm_timeout),
..
}) => Some(**sigterm_timeout),
_ => None,
},
None,
)
.await
{
Err(e) if e.kind == ErrorKind::NotFound => (), // Already stopped
a => a?,
}
Ok(())
}
}

View File

@@ -0,0 +1,888 @@
use std::collections::{BTreeMap, BTreeSet};
use std::net::Ipv4Addr;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
use color_eyre::eyre::eyre;
use container_init::ProcessGroupId;
use futures::future::BoxFuture;
use futures::{Future, FutureExt, TryFutureExt};
use helpers::UnixRpcClient;
use models::{ErrorKind, OptionExt, PackageId};
use nix::sys::signal::Signal;
use persistent_container::PersistentContainer;
use rand::SeedableRng;
use sqlx::Connection;
use start_stop::StartStop;
use tokio::sync::watch::{self, Sender};
use tokio::sync::{oneshot, Mutex};
use tracing::instrument;
use transition_state::TransitionState;
use crate::backup::target::PackageBackupInfo;
use crate::backup::PackageBackupReport;
use crate::config::action::ConfigRes;
use crate::config::spec::ValueSpecPointer;
use crate::config::ConfigureContext;
use crate::context::RpcContext;
use crate::db::model::{CurrentDependencies, CurrentDependencyInfo};
use crate::dependencies::{
add_dependent_to_current_dependents_lists, compute_dependency_config_errs,
};
use crate::disk::mount::backup::BackupMountGuard;
use crate::disk::mount::guard::TmpMountGuard;
use crate::install::cleanup::remove_from_current_dependents_lists;
use crate::net::net_controller::NetService;
use crate::net::vhost::AlpnInfo;
use crate::prelude::*;
use crate::procedure::docker::{DockerContainer, DockerProcedure, LongRunning};
use crate::procedure::{NoOutput, ProcedureName};
use crate::s9pk::manifest::Manifest;
use crate::status::MainStatus;
use crate::util::docker::{get_container_ip, kill_container};
use crate::util::NonDetachingJoinHandle;
use crate::volume::Volume;
use crate::Error;
pub mod health;
mod manager_container;
mod manager_map;
pub mod manager_seed;
mod persistent_container;
mod start_stop;
mod transition_state;
pub use manager_map::ManagerMap;
use self::manager_container::{get_status, ManageContainer};
use self::manager_seed::ManagerSeed;
pub const HEALTH_CHECK_COOLDOWN_SECONDS: u64 = 15;
pub const HEALTH_CHECK_GRACE_PERIOD_SECONDS: u64 = 5;
type ManagerPersistentContainer = Arc<Option<PersistentContainer>>;
type BackupGuard = Arc<Mutex<BackupMountGuard<TmpMountGuard>>>;
pub enum BackupReturn {
Error(Error),
AlreadyRunning(PackageBackupReport),
Ran {
report: PackageBackupReport,
res: Result<PackageBackupInfo, Error>,
},
}
pub struct Gid {
next_gid: (watch::Sender<u32>, watch::Receiver<u32>),
main_gid: (
watch::Sender<ProcessGroupId>,
watch::Receiver<ProcessGroupId>,
),
}
impl Default for Gid {
fn default() -> Self {
Self {
next_gid: watch::channel(1),
main_gid: watch::channel(ProcessGroupId(1)),
}
}
}
impl Gid {
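/// Allocates a fresh process group id by bumping the internal counter and returning its previous value.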
pub fn new_gid(&self) -> ProcessGroupId {
let mut previous = 0;
self.next_gid.0.send_modify(|x| {
previous = *x;
*x = previous + 1;
});
ProcessGroupId(previous)
}
pub fn new_main_gid(&self) -> ProcessGroupId {
let gid = self.new_gid();
self.main_gid.0.send(gid).unwrap_or_default();
gid
}
}
/// The controller of a service. This is where we control a service: start, stop, restart, etc.
#[derive(Clone)]
pub struct Manager {
seed: Arc<ManagerSeed>,
manage_container: Arc<manager_container::ManageContainer>,
transition: Arc<watch::Sender<TransitionState>>,
persistent_container: ManagerPersistentContainer,
pub gid: Arc<Gid>,
}
impl Manager {
pub async fn new(ctx: RpcContext, manifest: Manifest) -> Result<Self, Error> {
let seed = Arc::new(ManagerSeed {
ctx,
container_name: DockerProcedure::container_name(&manifest.id, None),
manifest,
});
let persistent_container = Arc::new(PersistentContainer::init(&seed).await?);
let manage_container = Arc::new(
manager_container::ManageContainer::new(seed.clone(), persistent_container.clone())
.await?,
);
let (transition, _) = watch::channel(Default::default());
let transition = Arc::new(transition);
Ok(Self {
seed,
manage_container,
transition,
persistent_container,
gid: Default::default(),
})
}
/// awaiting this does not wait for the start to complete
pub async fn start(&self) {
if self._is_transition_restart() {
return;
}
self._transition_abort().await;
self.manage_container.to_desired(StartStop::Start);
}
/// awaiting this does not wait for the stop to complete
pub async fn stop(&self) {
self._transition_abort().await;
self.manage_container.to_desired(StartStop::Stop);
}
/// awaiting this does not wait for the restart to complete
pub async fn restart(&self) {
if self._is_transition_restart()
&& *self.manage_container.desired_state().borrow() == StartStop::Stop
{
return;
}
if self.manage_container.desired_state().borrow().is_start() {
self._transition_replace(self._transition_restart()).await;
}
}
/// awaiting this does not wait for the restart to complete
pub async fn configure(
&self,
configure_context: ConfigureContext,
) -> Result<BTreeMap<PackageId, String>, Error> {
if self._is_transition_restart() {
self._transition_abort().await;
} else if self._is_transition_backup() {
return Err(Error::new(
eyre!("Can't configure because service is backing up"),
ErrorKind::InvalidRequest,
));
}
let context = self.seed.ctx.clone();
let id = self.seed.manifest.id.clone();
let breakages = configure(context, id, configure_context).await?;
self.restart().await;
Ok(breakages)
}
/// awaiting this does not wait for the backup to complete
pub async fn backup(&self, backup_guard: BackupGuard) -> BackupReturn {
if self._is_transition_backup() {
return BackupReturn::AlreadyRunning(PackageBackupReport {
error: Some("Can't do backup because service is already backing up".to_owned()),
});
}
let (transition_state, done) = self._transition_backup(backup_guard);
self._transition_replace(transition_state).await;
done.await
}
pub async fn exit(&self) {
self._transition_abort().await;
self.manage_container
.wait_for_desired(StartStop::Stop)
.await;
}
/// A special exit that overrides the start state; should only be called during shutdown, when we remove other containers
async fn shutdown(&self) -> Result<(), Error> {
self.manage_container.lock_state_forever(&self.seed).await?;
self.exit().await;
Ok(())
}
/// Send a signal to the service, e.g. when we want to shut it down
pub async fn signal(&self, signal: Signal) -> Result<(), Error> {
let gid = self.gid.clone();
send_signal(self, gid, signal).await
}
/// Getter for the RPC client; also used when executing procedures
pub fn rpc_client(&self) -> Option<Arc<UnixRpcClient>> {
(*self.persistent_container)
.as_ref()
.map(|x| x.rpc_client())
}
async fn _transition_abort(&self) {
self.transition
.send_replace(Default::default())
.abort()
.await;
}
async fn _transition_replace(&self, transition_state: TransitionState) {
self.transition.send_replace(transition_state).abort().await;
}
pub(super) fn perform_restart(&self) -> impl Future<Output = Result<(), Error>> + 'static {
let manage_container = self.manage_container.clone();
async move {
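// Restart = show a `Restarting` override, drive the desired state to Stop and wait, then back
// to Start and wait, and finally drop the override so the real status is reported again.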
let restart_override = manage_container.set_override(MainStatus::Restarting)?;
manage_container.wait_for_desired(StartStop::Stop).await;
manage_container.wait_for_desired(StartStop::Start).await;
restart_override.drop();
Ok(())
}
}
fn _transition_restart(&self) -> TransitionState {
let transition = self.transition.clone();
let restart = self.perform_restart();
TransitionState::Restarting(
tokio::spawn(async move {
if let Err(err) = restart.await {
tracing::error!("Error restarting service: {}", err);
}
transition.send_replace(Default::default());
})
.into(),
)
}
fn perform_backup(
&self,
backup_guard: BackupGuard,
) -> impl Future<Output = Result<Result<PackageBackupInfo, Error>, Error>> {
let manage_container = self.manage_container.clone();
let seed = self.seed.clone();
async move {
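// Backup flow: override the reported status to "backing up", stop the service, mount the
// package's backup directory, run the manifest's backup.create action, unmount, mark the
// package's backup progress complete in the db, then restore the previous desired state and
// drop the override.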
let peek = seed.ctx.db.peek().await;
let state_reverter = DesiredStateReverter::new(manage_container.clone());
let override_guard =
manage_container.set_override(get_status(peek, &seed.manifest).backing_up())?;
manage_container.wait_for_desired(StartStop::Stop).await;
let backup_guard = backup_guard.lock().await;
let guard = backup_guard.mount_package_backup(&seed.manifest.id).await?;
let return_value = seed.manifest.backup.create(seed.clone()).await;
guard.unmount().await?;
drop(backup_guard);
let manifest_id = seed.manifest.id.clone();
seed.ctx
.db
.mutate(|db| {
if let Some(progress) = db
.as_server_info_mut()
.as_status_info_mut()
.as_backup_progress_mut()
.transpose_mut()
.and_then(|p| p.as_idx_mut(&manifest_id))
{
progress.as_complete_mut().ser(&true)?;
}
Ok(())
})
.await?;
state_reverter.revert().await;
override_guard.drop();
Ok::<_, Error>(return_value)
}
}
fn _transition_backup(
&self,
backup_guard: BackupGuard,
) -> (TransitionState, BoxFuture<BackupReturn>) {
let (send, done) = oneshot::channel();
let transition_state = self.transition.clone();
(
TransitionState::BackingUp(
tokio::spawn(
self.perform_backup(backup_guard)
.then(finish_up_backup_task(transition_state, send)),
)
.into(),
),
done.map_err(|err| Error::new(eyre!("Oneshot error: {err:?}"), ErrorKind::Unknown))
.map(flatten_backup_error)
.boxed(),
)
}
fn _is_transition_restart(&self) -> bool {
let transition = self.transition.borrow();
matches!(*transition, TransitionState::Restarting(_))
}
fn _is_transition_backup(&self) -> bool {
let transition = self.transition.borrow();
matches!(*transition, TransitionState::BackingUp(_))
}
}
#[instrument(skip_all)]
async fn configure(
ctx: RpcContext,
id: PackageId,
mut configure_context: ConfigureContext,
) -> Result<BTreeMap<PackageId, String>, Error> {
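// High-level flow: load the manifest and current config, determine the new config (provided,
// existing, or generated from the spec), validate it against the spec, apply it via the
// package's config action (unless dry-run), recompute current dependencies and dependency
// config errors, check each dependent's config rules against the new config, then persist.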
let db = ctx.db.peek().await;
let id = &id;
let ctx = &ctx;
let overrides = &mut configure_context.overrides;
// fetch data from db
let manifest = db
.as_package_data()
.as_idx(id)
.or_not_found(id)?
.as_manifest()
.de()?;
// get current config and current spec
let ConfigRes {
config: old_config,
spec,
} = manifest
.config
.as_ref()
.or_not_found("Manifest config")?
.get(ctx, id, &manifest.version, &manifest.volumes)
.await?;
// determine new config to use
let mut config = if let Some(config) = configure_context.config.or_else(|| old_config.clone()) {
config
} else {
spec.gen(
&mut rand::rngs::StdRng::from_entropy(),
&configure_context.timeout,
)?
};
spec.validate(&manifest)?;
spec.matches(&config)?; // check that new config matches spec
// TODO Commit or not?
spec.update(ctx, &manifest, overrides, &mut config).await?; // dereference pointers in the new config
let manifest = db
.as_package_data()
.as_idx(id)
.or_not_found(id)?
.as_installed()
.or_not_found(id)?
.as_manifest()
.de()?;
let dependencies = &manifest.dependencies;
let mut current_dependencies: CurrentDependencies = CurrentDependencies(
dependencies
.0
.iter()
.filter_map(|(id, info)| {
if info.requirement.required() {
Some((id.clone(), CurrentDependencyInfo::default()))
} else {
None
}
})
.collect(),
);
for ptr in spec.pointers(&config)? {
match ptr {
ValueSpecPointer::Package(pkg_ptr) => {
if let Some(info) = current_dependencies.0.get_mut(pkg_ptr.package_id()) {
info.pointers.insert(pkg_ptr);
} else {
let id = pkg_ptr.package_id().to_owned();
let mut pointers = BTreeSet::new();
pointers.insert(pkg_ptr);
current_dependencies.0.insert(
id,
CurrentDependencyInfo {
pointers,
health_checks: BTreeSet::new(),
},
);
}
}
ValueSpecPointer::System(_) => (),
}
}
let action = manifest.config.as_ref().or_not_found(id)?;
let version = &manifest.version;
let volumes = &manifest.volumes;
if !configure_context.dry_run {
// run config action
let res = action
.set(ctx, id, version, &dependencies, volumes, &config)
.await?;
// track dependencies with no pointers
for (package_id, health_checks) in res.depends_on.into_iter() {
if let Some(current_dependency) = current_dependencies.0.get_mut(&package_id) {
current_dependency.health_checks.extend(health_checks);
} else {
current_dependencies.0.insert(
package_id,
CurrentDependencyInfo {
pointers: BTreeSet::new(),
health_checks,
},
);
}
}
// track dependency health checks
current_dependencies = current_dependencies.map(|x| {
x.into_iter()
.filter(|(dep_id, _)| {
if dep_id != id && !manifest.dependencies.0.contains_key(dep_id) {
tracing::warn!("Illegal dependency specified: {}", dep_id);
false
} else {
true
}
})
.collect()
});
}
let dependency_config_errs =
compute_dependency_config_errs(&ctx, &db, &manifest, &current_dependencies, overrides)
.await?;
// cache current config for dependents
configure_context
.overrides
.insert(id.clone(), config.clone());
// handle dependents
let dependents = db
.as_package_data()
.as_idx(id)
.or_not_found(id)?
.as_installed()
.or_not_found(id)?
.as_current_dependents()
.de()?;
for (dependent, _dep_info) in dependents.0.iter().filter(|(dep_id, _)| dep_id != &id) {
// check if config passes dependent check
if let Some(cfg) = db
.as_package_data()
.as_idx(dependent)
.or_not_found(dependent)?
.as_installed()
.or_not_found(dependent)?
.as_manifest()
.as_dependencies()
.as_idx(id)
.or_not_found(id)?
.as_config()
.de()?
{
let manifest = db
.as_package_data()
.as_idx(dependent)
.or_not_found(dependent)?
.as_installed()
.or_not_found(dependent)?
.as_manifest()
.de()?;
if let Err(error) = cfg
.check(
ctx,
dependent,
&manifest.version,
&manifest.volumes,
id,
&config,
)
.await?
{
configure_context.breakages.insert(dependent.clone(), error);
}
}
}
if !configure_context.dry_run {
return ctx
.db
.mutate(move |db| {
remove_from_current_dependents_lists(db, id, &current_dependencies)?;
add_dependent_to_current_dependents_lists(db, id, &current_dependencies)?;
current_dependencies.0.remove(id);
for (dep, errs) in db
.as_package_data_mut()
.as_entries_mut()?
.into_iter()
.filter_map(|(id, pde)| {
pde.as_installed_mut()
.map(|i| (id, i.as_status_mut().as_dependency_config_errors_mut()))
})
{
errs.remove(id)?;
if let Some(err) = configure_context.breakages.get(&dep) {
errs.insert(id, err)?;
}
}
let installed = db
.as_package_data_mut()
.as_idx_mut(id)
.or_not_found(id)?
.as_installed_mut()
.or_not_found(id)?;
installed
.as_current_dependencies_mut()
.ser(&current_dependencies)?;
let status = installed.as_status_mut();
status.as_configured_mut().ser(&true)?;
status
.as_dependency_config_errors_mut()
.ser(&dependency_config_errs)?;
Ok(configure_context.breakages)
})
.await; // add new
}
Ok(configure_context.breakages)
}
struct DesiredStateReverter {
manage_container: Option<Arc<ManageContainer>>,
starting_state: StartStop,
}
impl DesiredStateReverter {
fn new(manage_container: Arc<ManageContainer>) -> Self {
let starting_state = *manage_container.desired_state().borrow();
let manage_container = Some(manage_container);
Self {
starting_state,
manage_container,
}
}
async fn revert(mut self) {
if let Some(mut current_state) = self._revert() {
while *current_state.borrow() != self.starting_state {
current_state.changed().await.unwrap();
}
}
}
fn _revert(&mut self) -> Option<watch::Receiver<StartStop>> {
if let Some(manage_container) = self.manage_container.take() {
manage_container.to_desired(self.starting_state);
return Some(manage_container.desired_state());
}
None
}
}
impl Drop for DesiredStateReverter {
fn drop(&mut self) {
self._revert();
}
}
type BackupDoneSender = oneshot::Sender<Result<PackageBackupInfo, Error>>;
fn finish_up_backup_task(
transition: Arc<Sender<TransitionState>>,
send: BackupDoneSender,
) -> impl FnOnce(Result<Result<PackageBackupInfo, Error>, Error>) -> BoxFuture<'static, ()> {
move |result| {
async move {
transition.send_replace(Default::default());
send.send(match result {
Ok(a) => a,
Err(e) => Err(e),
})
.unwrap_or_default();
}
.boxed()
}
}
fn response_to_report(response: &Result<PackageBackupInfo, Error>) -> PackageBackupReport {
PackageBackupReport {
error: response.as_ref().err().map(|e| e.to_string()),
}
}
fn flatten_backup_error(input: Result<Result<PackageBackupInfo, Error>, Error>) -> BackupReturn {
match input {
Ok(a) => BackupReturn::Ran {
report: response_to_report(&a),
res: a,
},
Err(err) => BackupReturn::Error(err),
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Status {
Starting,
Running,
Stopped,
Paused,
Shutdown,
}
#[derive(Debug, Clone, Copy)]
pub enum OnStop {
Restart,
Sleep,
Exit,
}
type RunMainResult = Result<Result<NoOutput, (i32, String)>, Error>;
#[instrument(skip_all)]
async fn run_main(
seed: Arc<ManagerSeed>,
persistent_container: ManagerPersistentContainer,
started: Arc<impl Fn()>,
) -> RunMainResult {
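// Run the service's main procedure: spawn the image, wait for the container IP (skipped when a
// persistent container is used), attach LAN/Tor networking, mark the service as started, then
// race the runtime against the health-check daemon and tear networking down on exit.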
let mut runtime = NonDetachingJoinHandle::from(tokio::spawn(start_up_image(seed.clone())));
let ip = match persistent_container.is_some() {
false => Some(match get_running_ip(&seed, &mut runtime).await {
GetRunningIp::Ip(x) => x,
GetRunningIp::Error(e) => return Err(e),
GetRunningIp::EarlyExit(x) => return Ok(x),
}),
true => None,
};
let svc = if let Some(ip) = ip {
let net = add_network_for_main(&seed, ip).await?;
started();
Some(net)
} else {
None
};
let health = main_health_check_daemon(seed.clone());
let res = tokio::select! {
a = runtime => a.map_err(|_| Error::new(eyre!("Manager runtime panicked!"), crate::ErrorKind::Docker)).and_then(|a| a),
_ = health => Err(Error::new(eyre!("Health check daemon exited!"), crate::ErrorKind::Unknown))
};
if let Some(svc) = svc {
remove_network_for_main(svc).await?;
}
res
}
/// Start the main procedure from the manifest. In this case we also want to know that the certificates have been generated.
/// Note on _generated_certificate: we need to know, before starting, that the certificate has been generated.
async fn start_up_image(seed: Arc<ManagerSeed>) -> Result<Result<NoOutput, (i32, String)>, Error> {
seed.manifest
.main
.execute::<(), NoOutput>(
&seed.ctx,
&seed.manifest.id,
&seed.manifest.version,
ProcedureName::Main,
&seed.manifest.volumes,
None,
None,
)
.await
}
async fn long_running_docker(
seed: &ManagerSeed,
container: &DockerContainer,
) -> Result<(LongRunning, UnixRpcClient), Error> {
container
.long_running_execute(
&seed.ctx,
&seed.manifest.id,
&seed.manifest.version,
&seed.manifest.volumes,
)
.await
}
enum GetRunningIp {
Ip(Ipv4Addr),
Error(Error),
EarlyExit(Result<NoOutput, (i32, String)>),
}
async fn get_long_running_ip(seed: &ManagerSeed, runtime: &mut LongRunning) -> GetRunningIp {
loop {
match get_container_ip(&seed.container_name).await {
Ok(Some(ip_addr)) => return GetRunningIp::Ip(ip_addr),
Ok(None) => (),
Err(e) if e.kind == ErrorKind::NotFound => (),
Err(e) => return GetRunningIp::Error(e),
}
if let Poll::Ready(res) = futures::poll!(&mut runtime.running_output) {
match res {
Ok(_) => return GetRunningIp::EarlyExit(Ok(NoOutput)),
Err(_e) => {
return GetRunningIp::Error(Error::new(
eyre!("Manager runtime panicked!"),
crate::ErrorKind::Docker,
))
}
}
}
}
}
#[instrument(skip(seed))]
async fn add_network_for_main(
seed: &ManagerSeed,
ip: std::net::Ipv4Addr,
) -> Result<NetService, Error> {
let mut svc = seed
.ctx
.net_controller
.create_service(seed.manifest.id.clone(), ip)
.await?;
// DEPRECATED
let mut secrets = seed.ctx.secret_store.acquire().await?;
let mut tx = secrets.begin().await?;
for (id, interface) in &seed.manifest.interfaces.0 {
for (external, internal) in interface.lan_config.iter().flatten() {
svc.add_lan(
tx.as_mut(),
id.clone(),
external.0,
internal.internal,
Err(AlpnInfo::Specified(vec![])),
)
.await?;
}
for (external, internal) in interface.tor_config.iter().flat_map(|t| &t.port_mapping) {
svc.add_tor(tx.as_mut(), id.clone(), external.0, internal.0)
.await?;
}
}
for volume in seed.manifest.volumes.values() {
if let Volume::Certificate { interface_id } = volume {
svc.export_cert(tx.as_mut(), interface_id, ip.into())
.await?;
}
}
tx.commit().await?;
Ok(svc)
}
#[instrument(skip(svc))]
async fn remove_network_for_main(svc: NetService) -> Result<(), Error> {
svc.remove_all().await
}
async fn main_health_check_daemon(seed: Arc<ManagerSeed>) {
tokio::time::sleep(Duration::from_secs(HEALTH_CHECK_GRACE_PERIOD_SECONDS)).await;
loop {
if let Err(e) = health::check(&seed.ctx, &seed.manifest.id).await {
tracing::error!(
"Failed to run health check for {}: {}",
&seed.manifest.id,
e
);
tracing::debug!("{:?}", e);
}
tokio::time::sleep(Duration::from_secs(HEALTH_CHECK_COOLDOWN_SECONDS)).await;
}
}
type RuntimeOfCommand = NonDetachingJoinHandle<Result<Result<NoOutput, (i32, String)>, Error>>;
#[instrument(skip(seed, runtime))]
async fn get_running_ip(seed: &ManagerSeed, mut runtime: &mut RuntimeOfCommand) -> GetRunningIp {
loop {
match get_container_ip(&seed.container_name).await {
Ok(Some(ip_addr)) => return GetRunningIp::Ip(ip_addr),
Ok(None) => (),
Err(e) if e.kind == ErrorKind::NotFound => (),
Err(e) => return GetRunningIp::Error(e),
}
if let Poll::Ready(res) = futures::poll!(&mut runtime) {
match res {
Ok(Ok(response)) => return GetRunningIp::EarlyExit(response),
Err(e) => {
return GetRunningIp::Error(Error::new(
match e.try_into_panic() {
Ok(e) => {
eyre!(
"Manager runtime panicked: {}",
e.downcast_ref::<&'static str>().unwrap_or(&"UNKNOWN")
)
}
_ => eyre!("Manager runtime cancelled!"),
},
crate::ErrorKind::Docker,
))
}
Ok(Err(e)) => {
return GetRunningIp::Error(Error::new(
eyre!("Manager runtime returned error: {}", e),
crate::ErrorKind::Docker,
))
}
}
}
}
}
async fn send_signal(manager: &Manager, gid: Arc<Gid>, signal: Signal) -> Result<(), Error> {
// stop health checks from committing their results
// shared
// .commit_health_check_results
// .store(false, Ordering::SeqCst);
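// If the service exposes an RPC client (persistent container), deliver the signal to its main
// process group via the js SignalGroup procedure (when the js-engine feature is enabled);
// otherwise fall back to `docker kill` with the given signal (NotFound means already gone).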
if let Some(rpc_client) = manager.rpc_client() {
let main_gid = *gid.main_gid.0.borrow();
let next_gid = gid.new_gid();
#[cfg(feature = "js-engine")]
if let Err(e) = crate::procedure::js_scripts::JsProcedure::default()
.execute::<_, NoOutput>(
&manager.seed.ctx.datadir,
&manager.seed.manifest.id,
&manager.seed.manifest.version,
ProcedureName::Signal,
&manager.seed.manifest.volumes,
Some(container_init::SignalGroupParams {
gid: main_gid,
signal: signal as u32,
}),
None, // TODO
next_gid,
Some(rpc_client),
)
.await?
{
tracing::error!("Failed to send js signal: {}", e.1);
tracing::debug!("{:?}", e);
}
} else {
// send signal to container
kill_container(&manager.seed.container_name, Some(signal))
.await
.or_else(|e| {
if e.kind == ErrorKind::NotFound {
Ok(())
} else {
Err(e)
}
})?;
}
Ok(())
}

View File

@@ -0,0 +1,101 @@
use std::sync::Arc;
use std::time::Duration;
use color_eyre::eyre::eyre;
use helpers::UnixRpcClient;
use tokio::sync::oneshot;
use tokio::sync::watch::{self, Receiver};
use tracing::instrument;
use super::manager_seed::ManagerSeed;
use super::{
add_network_for_main, get_long_running_ip, long_running_docker, remove_network_for_main,
GetRunningIp,
};
use crate::procedure::docker::DockerContainer;
use crate::util::NonDetachingJoinHandle;
use crate::Error;
/// Persistent containers are the containers that need to run all the time.
/// The goal is that all services will eventually be persistent containers, waiting to run the main system.
pub struct PersistentContainer {
_running_docker: NonDetachingJoinHandle<()>,
pub rpc_client: Receiver<Arc<UnixRpcClient>>,
}
impl PersistentContainer {
#[instrument(skip_all)]
pub async fn init(seed: &Arc<ManagerSeed>) -> Result<Option<Self>, Error> {
Ok(if let Some(containers) = &seed.manifest.containers {
let (running_docker, rpc_client) =
spawn_persistent_container(seed.clone(), containers.main.clone()).await?;
Some(Self {
_running_docker: running_docker,
rpc_client,
})
} else {
None
})
}
pub fn rpc_client(&self) -> Arc<UnixRpcClient> {
self.rpc_client.borrow().clone()
}
}
pub async fn spawn_persistent_container(
seed: Arc<ManagerSeed>,
container: DockerContainer,
) -> Result<(NonDetachingJoinHandle<()>, Receiver<Arc<UnixRpcClient>>), Error> {
let (send_inserter, inserter) = oneshot::channel();
Ok((
tokio::task::spawn(async move {
let mut inserter_send: Option<watch::Sender<Arc<UnixRpcClient>>> = None;
let mut send_inserter: Option<oneshot::Sender<Receiver<Arc<UnixRpcClient>>>> = Some(send_inserter);
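// The first successful startup hands a watch::Receiver for the RPC client back through the
// oneshot; every later restart of the container only replaces the value in the watch channel.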
loop {
if let Err(e) = async {
let (mut runtime, inserter) =
long_running_docker(&seed, &container).await?;
let ip = match get_long_running_ip(&seed, &mut runtime).await {
GetRunningIp::Ip(x) => x,
GetRunningIp::Error(e) => return Err(e),
GetRunningIp::EarlyExit(e) => {
tracing::error!("Early Exit");
tracing::debug!("{:?}", e);
return Ok(());
}
};
let svc = add_network_for_main(&seed, ip).await?;
if let Some(inserter_send) = inserter_send.as_mut() {
let _ = inserter_send.send(Arc::new(inserter));
} else {
let (s, r) = watch::channel(Arc::new(inserter));
inserter_send = Some(s);
if let Some(send_inserter) = send_inserter.take() {
let _ = send_inserter.send(r);
}
}
let res = tokio::select! {
a = runtime.running_output => a.map_err(|_| Error::new(eyre!("Manager runtime panicked!"), crate::ErrorKind::Docker)).map(|_| ()),
};
remove_network_for_main(svc).await?;
res
}.await {
tracing::error!("Error in persistent container: {}", e);
tracing::debug!("{:?}", e);
} else {
break;
}
tokio::time::sleep(Duration::from_millis(200)).await;
}
})
.into(),
inserter.await.map_err(|_| Error::new(eyre!("Container handle dropped before inserter sent"), crate::ErrorKind::Unknown))?,
))
}

View File

@@ -0,0 +1,32 @@
use crate::status::MainStatus;
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StartStop {
Start,
Stop,
}
impl StartStop {
pub(crate) fn is_start(&self) -> bool {
matches!(self, StartStop::Start)
}
}
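/// Maps the persisted MainStatus onto the run state: a backup taken while the service was
/// running (`started` is Some) counts as Start so the service resumes afterwards.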
impl From<MainStatus> for StartStop {
fn from(value: MainStatus) -> Self {
match value {
MainStatus::Stopped => StartStop::Stop,
MainStatus::Restarting => StartStop::Start,
MainStatus::Stopping => StartStop::Stop,
MainStatus::Starting => StartStop::Start,
MainStatus::Running {
started: _,
health: _,
} => StartStop::Start,
MainStatus::BackingUp { started, health: _ } if started.is_some() => StartStop::Start,
MainStatus::BackingUp {
started: _,
health: _,
} => StartStop::Stop,
}
}
}

View File

@@ -0,0 +1,35 @@
use helpers::NonDetachingJoinHandle;
/// Used only in manager/mod to keep track of the manager's state during
/// transitional operations (backing up, restarting)
pub(super) enum TransitionState {
BackingUp(NonDetachingJoinHandle<()>),
Restarting(NonDetachingJoinHandle<()>),
None,
}
impl TransitionState {
pub(super) fn take(&mut self) -> Self {
std::mem::take(self)
}
pub(super) fn into_join_handle(self) -> Option<NonDetachingJoinHandle<()>> {
Some(match self {
TransitionState::BackingUp(a) => a,
TransitionState::Restarting(a) => a,
TransitionState::None => return None,
})
}
pub(super) async fn abort(&mut self) {
if let Some(s) = self.take().into_join_handle() {
if s.wait_for_abort().await.is_ok() {
tracing::trace!("transition completed before abort");
}
}
}
}
impl Default for TransitionState {
fn default() -> Self {
TransitionState::None
}
}

View File

@@ -0,0 +1,284 @@
use std::borrow::Borrow;
use std::sync::Arc;
use std::time::{Duration, Instant};
use basic_cookies::Cookie;
use color_eyre::eyre::eyre;
use digest::Digest;
use futures::future::BoxFuture;
use futures::FutureExt;
use http::StatusCode;
use rpc_toolkit::command_helpers::prelude::RequestParts;
use rpc_toolkit::hyper::header::COOKIE;
use rpc_toolkit::hyper::http::Error as HttpError;
use rpc_toolkit::hyper::{Body, Request, Response};
use rpc_toolkit::rpc_server_helpers::{
noop4, to_response, DynMiddleware, DynMiddlewareStage2, DynMiddlewareStage3,
};
use rpc_toolkit::yajrc::RpcMethod;
use rpc_toolkit::Metadata;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use tokio::sync::Mutex;
use crate::context::RpcContext;
use crate::{Error, ResultExt};
pub const LOCAL_AUTH_COOKIE_PATH: &str = "/run/embassy/rpc.authcookie";
pub trait AsLogoutSessionId {
fn as_logout_session_id(self) -> String;
}
/// Used to know that the given sessions have been logged out
#[derive(Serialize, Deserialize)]
pub struct HasLoggedOutSessions(());
impl HasLoggedOutSessions {
pub async fn new(
logged_out_sessions: impl IntoIterator<Item = impl AsLogoutSessionId>,
ctx: &RpcContext,
) -> Result<Self, Error> {
let mut open_authed_websockets = ctx.open_authed_websockets.lock().await;
let mut sqlx_conn = ctx.secret_store.acquire().await?;
for session in logged_out_sessions {
let session = session.as_logout_session_id();
sqlx::query!(
"UPDATE session SET logged_out = CURRENT_TIMESTAMP WHERE id = $1",
session
)
.execute(sqlx_conn.as_mut())
.await?;
for socket in open_authed_websockets.remove(&session).unwrap_or_default() {
let _ = socket.send(());
}
}
Ok(HasLoggedOutSessions(()))
}
}
/// Used when we need to know that the request comes from a valid, logged-in session
#[derive(Clone, Copy)]
pub struct HasValidSession(());
impl HasValidSession {
pub async fn from_request_parts(
request_parts: &RequestParts,
ctx: &RpcContext,
) -> Result<Self, Error> {
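// Accept either the `local` cookie (matched against the on-disk auth cookie at
// LOCAL_AUTH_COOKIE_PATH) or a `session` cookie (validated against the session table).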
if let Some(cookie_header) = request_parts.headers.get(COOKIE) {
let cookies = Cookie::parse(
cookie_header
.to_str()
.with_kind(crate::ErrorKind::Authorization)?,
)
.with_kind(crate::ErrorKind::Authorization)?;
if let Some(cookie) = cookies.iter().find(|c| c.get_name() == "local") {
if let Ok(s) = Self::from_local(cookie).await {
return Ok(s);
}
}
if let Some(cookie) = cookies.iter().find(|c| c.get_name() == "session") {
if let Ok(s) = Self::from_session(&HashSessionToken::from_cookie(cookie), ctx).await
{
return Ok(s);
}
}
}
Err(Error::new(
eyre!("UNAUTHORIZED"),
crate::ErrorKind::Authorization,
))
}
pub async fn from_session(session: &HashSessionToken, ctx: &RpcContext) -> Result<Self, Error> {
let session_hash = session.hashed();
let session = sqlx::query!("UPDATE session SET last_active = CURRENT_TIMESTAMP WHERE id = $1 AND logged_out IS NULL OR logged_out > CURRENT_TIMESTAMP", session_hash)
.execute(ctx.secret_store.acquire().await?.as_mut())
.await?;
if session.rows_affected() == 0 {
return Err(Error::new(
eyre!("UNAUTHORIZED"),
crate::ErrorKind::Authorization,
));
}
Ok(Self(()))
}
pub async fn from_local(local: &Cookie<'_>) -> Result<Self, Error> {
let token = tokio::fs::read_to_string(LOCAL_AUTH_COOKIE_PATH).await?;
if local.get_value() == &*token {
Ok(Self(()))
} else {
Err(Error::new(
eyre!("UNAUTHORIZED"),
crate::ErrorKind::Authorization,
))
}
}
}
/// Used when we need to create a new session,
/// or when an internal, already-authenticated service is making the request.
#[derive(Debug, Clone)]
pub struct HashSessionToken {
hashed: String,
token: String,
}
impl HashSessionToken {
pub fn new() -> Self {
let token = base32::encode(
base32::Alphabet::RFC4648 { padding: false },
&rand::random::<[u8; 16]>(),
)
.to_lowercase();
let hashed = Self::hash(&token);
Self { hashed, token }
}
pub fn from_cookie(cookie: &Cookie) -> Self {
let token = cookie.get_value().to_owned();
let hashed = Self::hash(&token);
Self { hashed, token }
}
pub fn from_request_parts(request_parts: &RequestParts) -> Result<Self, Error> {
if let Some(cookie_header) = request_parts.headers.get(COOKIE) {
let cookies = Cookie::parse(
cookie_header
.to_str()
.with_kind(crate::ErrorKind::Authorization)?,
)
.with_kind(crate::ErrorKind::Authorization)?;
if let Some(session) = cookies.iter().find(|c| c.get_name() == "session") {
return Ok(Self::from_cookie(session));
}
}
Err(Error::new(
eyre!("UNAUTHORIZED"),
crate::ErrorKind::Authorization,
))
}
pub fn header_value(&self) -> Result<http::HeaderValue, Error> {
http::HeaderValue::from_str(&format!(
"session={}; Path=/; SameSite=Lax; Expires=Fri, 31 Dec 9999 23:59:59 GMT;",
self.token
))
.with_kind(crate::ErrorKind::Unknown)
}
pub fn hashed(&self) -> &str {
self.hashed.as_str()
}
pub fn as_hash(self) -> String {
self.hashed
}
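/// The stored session id is the SHA-256 of the raw token, encoded as lowercase unpadded RFC 4648 base32.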
fn hash(token: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(token.as_bytes());
base32::encode(
base32::Alphabet::RFC4648 { padding: false },
hasher.finalize().as_slice(),
)
.to_lowercase()
}
}
impl AsLogoutSessionId for HashSessionToken {
fn as_logout_session_id(self) -> String {
self.hashed
}
}
impl PartialEq for HashSessionToken {
fn eq(&self, other: &Self) -> bool {
self.hashed == other.hashed
}
}
impl Eq for HashSessionToken {}
impl PartialOrd for HashSessionToken {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.hashed.partial_cmp(&other.hashed)
}
}
impl Ord for HashSessionToken {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.hashed.cmp(&other.hashed)
}
}
impl Borrow<String> for HashSessionToken {
fn borrow(&self) -> &String {
&self.hashed
}
}
pub fn auth<M: Metadata>(ctx: RpcContext) -> DynMiddleware<M> {
let rate_limiter = Arc::new(Mutex::new((0_usize, Instant::now())));
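// (failure_count, window_start): failed requests within a 20-second window are counted; once 3
// have accumulated, further `auth.login` attempts are rejected until the window expires.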
Box::new(
move |req: &mut Request<Body>,
metadata: M|
-> BoxFuture<Result<Result<DynMiddlewareStage2, Response<Body>>, HttpError>> {
let ctx = ctx.clone();
let rate_limiter = rate_limiter.clone();
async move {
let mut header_stub = Request::new(Body::empty());
*header_stub.headers_mut() = req.headers().clone();
let m2: DynMiddlewareStage2 = Box::new(move |req, rpc_req| {
async move {
if let Err(e) = HasValidSession::from_request_parts(req, &ctx).await {
if metadata
.get(rpc_req.method.as_str(), "authenticated")
.unwrap_or(true)
{
let (res_parts, _) = Response::new(()).into_parts();
return Ok(Err(to_response(
&req.headers,
res_parts,
Err(e.into()),
|_| StatusCode::OK,
)?));
} else if rpc_req.method.as_str() == "auth.login" {
let guard = rate_limiter.lock().await;
if guard.1.elapsed() < Duration::from_secs(20) {
if guard.0 >= 3 {
let (res_parts, _) = Response::new(()).into_parts();
return Ok(Err(to_response(
&req.headers,
res_parts,
Err(Error::new(
eyre!(
"Please limit login attempts to 3 per 20 seconds."
),
crate::ErrorKind::RateLimited,
)
.into()),
|_| StatusCode::OK,
)?));
}
}
}
}
let m3: DynMiddlewareStage3 = Box::new(move |_, res| {
async move {
let mut guard = rate_limiter.lock().await;
if guard.1.elapsed() < Duration::from_secs(20) {
if res.is_err() {
guard.0 += 1;
}
} else {
guard.0 = 0;
}
guard.1 = Instant::now();
Ok(Ok(noop4()))
}
.boxed()
});
Ok(Ok(m3))
}
.boxed()
});
Ok(Ok(m2))
}
.boxed()
},
)
}

View File

@@ -0,0 +1,61 @@
use futures::FutureExt;
use http::HeaderValue;
use hyper::header::HeaderMap;
use rpc_toolkit::hyper::http::Error as HttpError;
use rpc_toolkit::hyper::{Body, Method, Request, Response};
use rpc_toolkit::rpc_server_helpers::{
DynMiddlewareStage2, DynMiddlewareStage3, DynMiddlewareStage4,
};
use rpc_toolkit::Metadata;
fn get_cors_headers(req: &Request<Body>) -> HeaderMap {
let mut res = HeaderMap::new();
if let Some(origin) = req.headers().get("Origin") {
res.insert("Access-Control-Allow-Origin", origin.clone());
}
if let Some(method) = req.headers().get("Access-Control-Request-Method") {
res.insert("Access-Control-Allow-Methods", method.clone());
}
if let Some(headers) = req.headers().get("Access-Control-Request-Headers") {
res.insert("Access-Control-Allow-Headers", headers.clone());
}
res.insert(
"Access-Control-Allow-Credentials",
HeaderValue::from_static("true"),
);
res
}
pub async fn cors<M: Metadata>(
req: &mut Request<Body>,
_metadata: M,
) -> Result<Result<DynMiddlewareStage2, Response<Body>>, HttpError> {
let headers = get_cors_headers(req);
if req.method() == Method::OPTIONS {
Ok(Err({
let mut res = Response::new(Body::empty());
res.headers_mut().extend(headers.into_iter());
res
}))
} else {
Ok(Ok(Box::new(|_, _| {
async move {
let res: DynMiddlewareStage3 = Box::new(|_, _| {
async move {
let res: DynMiddlewareStage4 = Box::new(|res| {
async move {
res.headers_mut().extend(headers.into_iter());
Ok::<_, HttpError>(())
}
.boxed()
});
Ok::<_, HttpError>(Ok(res))
}
.boxed()
});
Ok::<_, HttpError>(Ok(res))
}
.boxed()
})))
}
}

View File

@@ -0,0 +1,50 @@
use futures::future::BoxFuture;
use futures::FutureExt;
use http::HeaderValue;
use rpc_toolkit::hyper::http::Error as HttpError;
use rpc_toolkit::hyper::{Body, Request, Response};
use rpc_toolkit::rpc_server_helpers::{
noop4, DynMiddleware, DynMiddlewareStage2, DynMiddlewareStage3,
};
use rpc_toolkit::yajrc::RpcMethod;
use rpc_toolkit::Metadata;
use crate::context::RpcContext;
pub fn db<M: Metadata>(ctx: RpcContext) -> DynMiddleware<M> {
Box::new(
move |_: &mut Request<Body>,
metadata: M|
-> BoxFuture<Result<Result<DynMiddlewareStage2, Response<Body>>, HttpError>> {
let ctx = ctx.clone();
async move {
let m2: DynMiddlewareStage2 = Box::new(move |_req, rpc_req| {
async move {
let sync_db = metadata
.get(rpc_req.method.as_str(), "sync_db")
.unwrap_or(false);
let m3: DynMiddlewareStage3 = Box::new(move |res, _| {
async move {
if sync_db {
res.headers.append(
"X-Patch-Sequence",
HeaderValue::from_str(
&ctx.db.sequence().await.to_string(),
)?,
);
}
Ok(Ok(noop4()))
}
.boxed()
});
Ok(Ok(m3))
}
.boxed()
});
Ok(Ok(m2))
}
.boxed()
},
)
}

View File

@@ -0,0 +1,39 @@
use futures::FutureExt;
use rpc_toolkit::hyper::http::Error as HttpError;
use rpc_toolkit::hyper::{Body, Request, Response};
use rpc_toolkit::rpc_server_helpers::{noop4, DynMiddlewareStage2, DynMiddlewareStage3};
use rpc_toolkit::yajrc::RpcMethod;
use rpc_toolkit::Metadata;
use crate::Error;
pub async fn diagnostic<M: Metadata>(
_req: &mut Request<Body>,
_metadata: M,
) -> Result<Result<DynMiddlewareStage2, Response<Body>>, HttpError> {
Ok(Ok(Box::new(|_, rpc_req| {
let method = rpc_req.method.as_str().to_owned();
async move {
let res: DynMiddlewareStage3 = Box::new(|_, rpc_res| {
async move {
if let Err(e) = rpc_res {
if e.code == -32601 {
*e = Error::new(
color_eyre::eyre::eyre!(
"{} is not available on the Diagnostic API",
method
),
crate::ErrorKind::DiagnosticMode,
)
.into();
}
}
Ok(Ok(noop4()))
}
.boxed()
});
Ok::<_, HttpError>(Ok(res))
}
.boxed()
})))
}

View File

@@ -0,0 +1,115 @@
use aes::cipher::{CipherKey, NewCipher, Nonce, StreamCipher};
use aes::Aes256Ctr;
use hmac::Hmac;
use josekit::jwk::Jwk;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use tracing::instrument;
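/// Derives an AES-256-CTR key from the password and salt using PBKDF2-HMAC-SHA256 with 1000 iterations.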
pub fn pbkdf2(password: impl AsRef<[u8]>, salt: impl AsRef<[u8]>) -> CipherKey<Aes256Ctr> {
let mut aeskey = CipherKey::<Aes256Ctr>::default();
pbkdf2::pbkdf2::<Hmac<Sha256>>(
password.as_ref(),
salt.as_ref(),
1000,
aeskey.as_mut_slice(),
)
.unwrap();
aeskey
}
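// Wire format produced below: 16-byte CTR nonce || 16-byte PBKDF2 salt || ciphertext
// (the key is re-derived from the salt on decryption).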
pub fn encrypt_slice(input: impl AsRef<[u8]>, password: impl AsRef<[u8]>) -> Vec<u8> {
let prefix: [u8; 32] = rand::random();
let aeskey = pbkdf2(password.as_ref(), &prefix[16..]);
let ctr = Nonce::<Aes256Ctr>::from_slice(&prefix[..16]);
let mut aes = Aes256Ctr::new(&aeskey, ctr);
let mut res = Vec::with_capacity(32 + input.as_ref().len());
res.extend_from_slice(&prefix[..]);
res.extend_from_slice(input.as_ref());
aes.apply_keystream(&mut res[32..]);
res
}
pub fn decrypt_slice(input: impl AsRef<[u8]>, password: impl AsRef<[u8]>) -> Vec<u8> {
if input.as_ref().len() < 32 {
return Vec::new();
}
let (prefix, rest) = input.as_ref().split_at(32);
let aeskey = pbkdf2(password.as_ref(), &prefix[16..]);
let ctr = Nonce::<Aes256Ctr>::from_slice(&prefix[..16]);
let mut aes = Aes256Ctr::new(&aeskey, ctr);
let mut res = rest.to_vec();
aes.apply_keystream(&mut res);
res
}
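// A minimal round-trip sketch of the two helpers above; decrypt_slice re-derives the key from
// the salt/nonce prefix it reads back, so the original bytes come out unchanged.
#[test]
fn encrypt_decrypt_roundtrip_sketch() {
let msg = b"hello world";
let enc = encrypt_slice(msg, "password");
assert_eq!(enc.len(), 32 + msg.len());
assert_eq!(decrypt_slice(&enc, "password"), msg.to_vec());
}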
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct EncryptedWire {
encrypted: serde_json::Value,
}
impl EncryptedWire {
#[instrument(skip_all)]
pub fn decrypt(self, current_secret: impl AsRef<Jwk>) -> Option<String> {
let current_secret = current_secret.as_ref();
let decrypter = match josekit::jwe::alg::ecdh_es::EcdhEsJweAlgorithm::EcdhEs
.decrypter_from_jwk(current_secret)
{
Ok(a) => a,
Err(e) => {
tracing::warn!("Could not setup awk");
tracing::debug!("{:?}", e);
return None;
}
};
let encrypted = match serde_json::to_string(&self.encrypted) {
Ok(a) => a,
Err(e) => {
tracing::warn!("Could not deserialize");
tracing::debug!("{:?}", e);
return None;
}
};
let (decoded, _) = match josekit::jwe::deserialize_json(&encrypted, &decrypter) {
Ok(a) => a,
Err(e) => {
tracing::warn!("Could not decrypt");
tracing::debug!("{:?}", e);
return None;
}
};
match String::from_utf8(decoded) {
Ok(a) => Some(a),
Err(e) => {
tracing::warn!("Could not decrypt into utf8");
tracing::debug!("{:?}", e);
return None;
}
}
}
}
/// We created this test by first generating the private key, then restoring from it for reproducibility.
/// After that, the frontend encrypted a password, and we test that the (hand-coded) output it produced
/// decrypts to the value we expect.
#[test]
fn test_gen_awk() {
let private_key: Jwk = serde_json::from_str(
r#"{
"kty": "EC",
"crv": "P-256",
"d": "3P-MxbUJtEhdGGpBCRFXkUneGgdyz_DGZWfIAGSCHOU",
"x": "yHTDYSfjU809fkSv9MmN4wuojf5c3cnD7ZDN13n-jz4",
"y": "8Mpkn744A5KDag0DmX2YivB63srjbugYZzWc3JOpQXI"
}"#,
)
.unwrap();
let encrypted: EncryptedWire = serde_json::from_str(r#"{
"encrypted": { "protected": "eyJlbmMiOiJBMTI4Q0JDLUhTMjU2IiwiYWxnIjoiRUNESC1FUyIsImtpZCI6ImgtZnNXUVh2Tm95dmJEazM5dUNsQ0NUdWc5N3MyZnJockJnWUVBUWVtclUiLCJlcGsiOnsia3R5IjoiRUMiLCJjcnYiOiJQLTI1NiIsIngiOiJmRkF0LXNWYWU2aGNkdWZJeUlmVVdUd3ZvWExaTkdKRHZIWVhIckxwOXNNIiwieSI6IjFvVFN6b00teHlFZC1SLUlBaUFHdXgzS1dJZmNYZHRMQ0JHLUh6MVkzY2sifX0", "iv": "NbwvfvWOdLpZfYRIZUrkcw", "ciphertext": "Zc5Br5kYOlhPkIjQKOLMJw", "tag": "EPoch52lDuCsbUUulzZGfg" }
}"#).unwrap();
assert_eq!(
"testing12345",
&encrypted.decrypt(std::sync::Arc::new(private_key)).unwrap()
);
}

View File

@@ -0,0 +1,5 @@
pub mod auth;
pub mod cors;
pub mod db;
pub mod diagnostic;
pub mod encrypt;

View File

@@ -0,0 +1,7 @@
load database
from sqlite://{sqlite_path}
into postgresql://root@unix:/var/run/postgresql:5432/secrets
with include no drop, truncate, reset sequences, data only, workers = 1, concurrency = 1, max parallel create index = 1, batch rows = {batch_rows}, prefetch rows = {prefetch_rows}
excluding table names like '_sqlx_migrations', 'notifications';

View File

@@ -0,0 +1,141 @@
use std::collections::BTreeSet;
use color_eyre::eyre::eyre;
use emver::VersionRange;
use futures::{Future, FutureExt};
use indexmap::IndexMap;
use models::ImageId;
use patch_db::HasModel;
use serde::{Deserialize, Serialize};
use tracing::instrument;
use crate::context::RpcContext;
use crate::prelude::*;
use crate::procedure::docker::DockerContainers;
use crate::procedure::{PackageProcedure, ProcedureName};
use crate::s9pk::manifest::PackageId;
use crate::util::Version;
use crate::volume::Volumes;
use crate::{Error, ResultExt};
#[derive(Clone, Debug, Default, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct Migrations {
pub from: IndexMap<VersionRange, PackageProcedure>,
pub to: IndexMap<VersionRange, PackageProcedure>,
}
impl Migrations {
#[instrument(skip_all)]
pub fn validate(
&self,
_container: &Option<DockerContainers>,
eos_version: &Version,
volumes: &Volumes,
image_ids: &BTreeSet<ImageId>,
) -> Result<(), Error> {
for (version, migration) in &self.from {
migration
.validate(eos_version, volumes, image_ids, true)
.with_ctx(|_| {
(
crate::ErrorKind::ValidateS9pk,
format!("Migration from {}", version),
)
})?;
}
for (version, migration) in &self.to {
migration
.validate(eos_version, volumes, image_ids, true)
.with_ctx(|_| {
(
crate::ErrorKind::ValidateS9pk,
format!("Migration to {}", version),
)
})?;
}
Ok(())
}
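/// Selects the `from` migration whose version range matches the version being migrated from,
/// if any, and returns a future that runs it.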
#[instrument(skip_all)]
pub fn from<'a>(
&'a self,
_container: &'a Option<DockerContainers>,
ctx: &'a RpcContext,
version: &'a Version,
pkg_id: &'a PackageId,
pkg_version: &'a Version,
volumes: &'a Volumes,
) -> Option<impl Future<Output = Result<MigrationRes, Error>> + 'a> {
if let Some((_, migration)) = self
.from
.iter()
.find(|(range, _)| version.satisfies(*range))
{
Some(async move {
migration
.execute(
ctx,
pkg_id,
pkg_version,
ProcedureName::Migration, // Migrations cannot be executed concurrently
volumes,
Some(version),
None,
)
.map(|r| {
r.and_then(|r| {
r.map_err(|e| {
Error::new(eyre!("{}", e.1), crate::ErrorKind::MigrationFailed)
})
})
})
.await
})
} else {
None
}
}
#[instrument(skip_all)]
pub fn to<'a>(
&'a self,
ctx: &'a RpcContext,
version: &'a Version,
pkg_id: &'a PackageId,
pkg_version: &'a Version,
volumes: &'a Volumes,
) -> Option<impl Future<Output = Result<MigrationRes, Error>> + 'a> {
if let Some((_, migration)) = self.to.iter().find(|(range, _)| version.satisfies(*range)) {
Some(async move {
migration
.execute(
ctx,
pkg_id,
pkg_version,
ProcedureName::Migration,
volumes,
Some(version),
None,
)
.map(|r| {
r.and_then(|r| {
r.map_err(|e| {
Error::new(eyre!("{}", e.1), crate::ErrorKind::MigrationFailed)
})
})
})
.await
})
} else {
None
}
}
}
#[derive(Clone, Debug, Default, Deserialize, Serialize, HasModel)]
#[serde(rename_all = "kebab-case")]
#[model = "Model<Self>"]
pub struct MigrationRes {
pub configured: bool,
}

View File

@@ -0,0 +1,10 @@
[req]
default_bits = 4096
default_md = sha256
distinguished_name = req_distinguished_name
prompt = no
[req_distinguished_name]
CN = {hostname}.local
O = Start9 Labs
OU = StartOS

View File

@@ -0,0 +1,82 @@
use std::collections::{BTreeMap, BTreeSet};
use std::net::IpAddr;
use futures::TryStreamExt;
use rpc_toolkit::command;
use tokio::sync::RwLock;
use crate::context::RpcContext;
use crate::db::model::IpInfo;
use crate::net::utils::{iface_is_physical, list_interfaces};
use crate::prelude::*;
use crate::util::display_none;
use crate::Error;
lazy_static::lazy_static! {
static ref CACHED_IPS: RwLock<BTreeSet<IpAddr>> = RwLock::new(BTreeSet::new());
}
async fn _ips() -> Result<BTreeSet<IpAddr>, Error> {
Ok(init_ips()
.await?
.values()
.flat_map(|i| {
std::iter::empty()
.chain(i.ipv4.map(IpAddr::from))
.chain(i.ipv6.map(IpAddr::from))
})
.collect())
}
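// Lazily populated cache: the first call walks the physical interfaces; later calls return the
// cached set, which dhcp::update extends as leases change.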
pub async fn ips() -> Result<BTreeSet<IpAddr>, Error> {
let ips = CACHED_IPS.read().await.clone();
if !ips.is_empty() {
return Ok(ips);
}
let ips = _ips().await?;
*CACHED_IPS.write().await = ips.clone();
Ok(ips)
}
pub async fn init_ips() -> Result<BTreeMap<String, IpInfo>, Error> {
let mut res = BTreeMap::new();
let mut ifaces = list_interfaces();
while let Some(iface) = ifaces.try_next().await? {
if iface_is_physical(&iface).await {
let ip_info = IpInfo::for_interface(&iface).await?;
res.insert(iface, ip_info);
}
}
Ok(res)
}
#[command(subcommands(update))]
pub async fn dhcp() -> Result<(), Error> {
Ok(())
}
#[command(display(display_none))]
pub async fn update(#[context] ctx: RpcContext, #[arg] interface: String) -> Result<(), Error> {
if iface_is_physical(&interface).await {
let ip_info = IpInfo::for_interface(&interface).await?;
ctx.db
.mutate(|db| {
db.as_server_info_mut()
.as_ip_info_mut()
.insert(&interface, &ip_info)
})
.await?;
let mut cached = CACHED_IPS.write().await;
if cached.is_empty() {
*cached = _ips().await?;
} else {
cached.extend(
std::iter::empty()
.chain(ip_info.ipv4.map(IpAddr::from))
.chain(ip_info.ipv6.map(IpAddr::from)),
);
}
}
Ok(())
}

228
core/startos/src/net/dns.rs Normal file
View File

@@ -0,0 +1,228 @@
use std::borrow::Borrow;
use std::collections::BTreeMap;
use std::net::{Ipv4Addr, SocketAddr};
use std::sync::{Arc, Weak};
use std::time::Duration;
use color_eyre::eyre::eyre;
use futures::TryFutureExt;
use helpers::NonDetachingJoinHandle;
use models::PackageId;
use tokio::net::{TcpListener, UdpSocket};
use tokio::process::Command;
use tokio::sync::RwLock;
use tracing::instrument;
use trust_dns_server::authority::MessageResponseBuilder;
use trust_dns_server::proto::op::{Header, ResponseCode};
use trust_dns_server::proto::rr::{Name, Record, RecordType};
use trust_dns_server::server::{Request, RequestHandler, ResponseHandler, ResponseInfo};
use trust_dns_server::ServerFuture;
use crate::util::Invoke;
use crate::{Error, ErrorKind, ResultExt};
pub struct DnsController {
services: Weak<RwLock<BTreeMap<Option<PackageId>, BTreeMap<Ipv4Addr, Weak<()>>>>>,
#[allow(dead_code)]
dns_server: NonDetachingJoinHandle<Result<(), Error>>,
}
struct Resolver {
services: Arc<RwLock<BTreeMap<Option<PackageId>, BTreeMap<Ipv4Addr, Weak<()>>>>>,
}
impl Resolver {
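/// Resolves `<package-id>.embassy` to the container IPs registered for that package, and a bare
/// `embassy` name to the IPs registered under `None`; only entries whose liveness token is
/// still alive (strong_count > 0) are returned.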
async fn resolve(&self, name: &Name) -> Option<Vec<Ipv4Addr>> {
match name.iter().next_back() {
Some(b"embassy") => {
if let Some(pkg) = name.iter().rev().skip(1).next() {
if let Some(ip) = self.services.read().await.get(&Some(
std::str::from_utf8(pkg)
.unwrap_or_default()
.parse()
.unwrap_or_default(),
)) {
Some(
ip.iter()
.filter(|(_, rc)| rc.strong_count() > 0)
.map(|(ip, _)| *ip)
.collect(),
)
} else {
None
}
} else if let Some(ip) = self.services.read().await.get(&None) {
Some(
ip.iter()
.filter(|(_, rc)| rc.strong_count() > 0)
.map(|(ip, _)| *ip)
.collect(),
)
} else {
None
}
}
_ => None,
}
}
}
#[async_trait::async_trait]
impl RequestHandler for Resolver {
async fn handle_request<R: ResponseHandler>(
&self,
request: &Request,
mut response_handle: R,
) -> ResponseInfo {
let query = request.request_info().query;
if let Some(ip) = self.resolve(query.name().borrow()).await {
match query.query_type() {
RecordType::A => {
response_handle
.send_response(
MessageResponseBuilder::from_message_request(&*request).build(
Header::response_from_request(request.header()),
&ip.into_iter()
.map(|ip| {
Record::from_rdata(
request.request_info().query.name().to_owned().into(),
0,
trust_dns_server::proto::rr::RData::A(ip.into()),
)
})
.collect::<Vec<_>>(),
[],
[],
[],
),
)
.await
}
a => {
if a != RecordType::AAAA {
tracing::warn!(
"Non A-Record requested for {}: {:?}",
query.name(),
query.query_type()
);
}
let mut res = Header::response_from_request(request.header());
res.set_response_code(ResponseCode::NXDomain);
response_handle
.send_response(
MessageResponseBuilder::from_message_request(&*request).build(
res.into(),
[],
[],
[],
[],
),
)
.await
}
}
} else {
let mut res = Header::response_from_request(request.header());
res.set_response_code(ResponseCode::NXDomain);
response_handle
.send_response(
MessageResponseBuilder::from_message_request(&*request).build(
res.into(),
[],
[],
[],
[],
),
)
.await
}
.unwrap_or_else(|e| {
tracing::error!("{}", e);
tracing::debug!("{:?}", e);
let mut res = Header::response_from_request(request.header());
res.set_response_code(ResponseCode::ServFail);
res.into()
})
}
}
impl DnsController {
#[instrument(skip_all)]
pub async fn init(bind: &[SocketAddr]) -> Result<Self, Error> {
let services = Arc::new(RwLock::new(BTreeMap::new()));
let mut server = ServerFuture::new(Resolver {
services: services.clone(),
});
server.register_listener(
TcpListener::bind(bind)
.await
.with_kind(ErrorKind::Network)?,
Duration::from_secs(30),
);
server.register_socket(UdpSocket::bind(bind).await.with_kind(ErrorKind::Network)?);
Command::new("resolvectl")
.arg("dns")
.arg("br-start9")
.arg("127.0.0.1")
.invoke(ErrorKind::Network)
.await?;
Command::new("resolvectl")
.arg("domain")
.arg("br-start9")
.arg("embassy")
.invoke(ErrorKind::Network)
.await?;
let dns_server = tokio::spawn(
server
.block_until_done()
.map_err(|e| Error::new(e, ErrorKind::Network)),
)
.into();
Ok(Self {
services: Arc::downgrade(&services),
dns_server,
})
}
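/// Registers `ip` for `pkg_id` and returns a token; the record remains
/// resolvable only while at least one strong reference to that token is held.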
pub async fn add(&self, pkg_id: Option<PackageId>, ip: Ipv4Addr) -> Result<Arc<()>, Error> {
if let Some(services) = Weak::upgrade(&self.services) {
let mut writable = services.write().await;
let mut ips = writable.remove(&pkg_id).unwrap_or_default();
let rc = if let Some(rc) = Weak::upgrade(&ips.remove(&ip).unwrap_or_default()) {
rc
} else {
Arc::new(())
};
ips.insert(ip, Arc::downgrade(&rc));
writable.insert(pkg_id, ips);
Ok(rc)
} else {
Err(Error::new(
eyre!("DNS Server Thread has exited"),
crate::ErrorKind::Network,
))
}
}
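/// Prunes the entry for `ip` if all of its tokens have been dropped, removing
/// the package from the map entirely once it has no live IPs left.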
pub async fn gc(&self, pkg_id: Option<PackageId>, ip: Ipv4Addr) -> Result<(), Error> {
if let Some(services) = Weak::upgrade(&self.services) {
let mut writable = services.write().await;
let mut ips = writable.remove(&pkg_id).unwrap_or_default();
if let Some(rc) = Weak::upgrade(&ips.remove(&ip).unwrap_or_default()) {
ips.insert(ip, Arc::downgrade(&rc));
}
if !ips.is_empty() {
writable.insert(pkg_id, ips);
}
Ok(())
} else {
Err(Error::new(
eyre!("DNS Server Thread has exited"),
crate::ErrorKind::Network,
))
}
}
}

View File

@@ -0,0 +1,122 @@
use std::collections::BTreeMap;
use indexmap::IndexSet;
pub use models::InterfaceId;
use serde::{Deserialize, Deserializer, Serialize};
use sqlx::{Executor, Postgres};
use tracing::instrument;
use crate::db::model::{InterfaceAddressMap, InterfaceAddresses};
use crate::net::keys::Key;
use crate::s9pk::manifest::PackageId;
use crate::util::serde::Port;
use crate::{Error, ResultExt};
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct Interfaces(pub BTreeMap<InterfaceId, Interface>); // TODO
impl Interfaces {
#[instrument(skip_all)]
pub fn validate(&self) -> Result<(), Error> {
for interface in self.0.values() {
interface.validate().with_ctx(|_| {
(
crate::ErrorKind::ValidateS9pk,
format!("Interface {}", interface.name),
)
})?;
}
Ok(())
}
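/// Ensures a network key exists for every interface that exposes tor or lan
/// config and returns the derived tor and `.local` addresses for each one.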
#[instrument(skip_all)]
pub async fn install<Ex>(
&self,
secrets: &mut Ex,
package_id: &PackageId,
) -> Result<InterfaceAddressMap, Error>
where
for<'a> &'a mut Ex: Executor<'a, Database = Postgres>,
{
let mut interface_addresses = InterfaceAddressMap(BTreeMap::new());
for (id, iface) in &self.0 {
let mut addrs = InterfaceAddresses {
tor_address: None,
lan_address: None,
};
if iface.tor_config.is_some() || iface.lan_config.is_some() {
let key =
Key::for_interface(secrets, Some((package_id.clone(), id.clone()))).await?;
if iface.tor_config.is_some() {
addrs.tor_address = Some(key.tor_address().to_string());
}
if iface.lan_config.is_some() {
addrs.lan_address = Some(key.local_address());
}
}
interface_addresses.0.insert(id.clone(), addrs);
}
Ok(interface_addresses)
}
}
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct Interface {
pub name: String,
pub description: String,
pub tor_config: Option<TorConfig>,
pub lan_config: Option<BTreeMap<Port, LanPortConfig>>,
pub ui: bool,
pub protocols: IndexSet<String>,
}
impl Interface {
#[instrument(skip_all)]
pub fn validate(&self) -> Result<(), color_eyre::eyre::Report> {
if self.tor_config.is_some() && !self.protocols.contains("tcp") {
color_eyre::eyre::bail!("must support tcp to set up a tor hidden service");
}
if self.lan_config.is_some() && !self.protocols.contains("http") {
color_eyre::eyre::bail!("must support http to set up a lan service");
}
if self.ui && !(self.protocols.contains("http") || self.protocols.contains("https")) {
color_eyre::eyre::bail!("must support http or https to serve a ui");
}
Ok(())
}
}
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct TorConfig {
pub port_mapping: BTreeMap<Port, Port>,
}
#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "kebab-case")]
pub struct LanPortConfig {
pub ssl: bool,
pub internal: u16,
}
impl<'de> Deserialize<'de> for LanPortConfig {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(rename_all = "kebab-case")]
struct PermissiveLanPortConfig {
ssl: bool,
internal: Option<u16>,
mapping: Option<u16>,
}
let config = PermissiveLanPortConfig::deserialize(deserializer)?;
Ok(LanPortConfig {
ssl: config.ssl,
internal: config
.internal
.or(config.mapping)
.ok_or_else(|| serde::de::Error::missing_field("internal"))?,
})
}
}

View File

@@ -0,0 +1,273 @@
use color_eyre::eyre::eyre;
use models::{Id, InterfaceId, PackageId};
use openssl::pkey::{PKey, Private};
use openssl::sha::Sha256;
use openssl::x509::X509;
use p256::elliptic_curve::pkcs8::EncodePrivateKey;
use sqlx::PgExecutor;
use ssh_key::private::Ed25519PrivateKey;
use torut::onion::{OnionAddressV3, TorSecretKeyV3};
use zeroize::Zeroize;
use crate::net::ssl::CertPair;
use crate::prelude::*;
use crate::util::crypto::ed25519_expand_key;
// TODO: delete once tor addresses can be changed
async fn compat(
secrets: impl PgExecutor<'_>,
interface: &Option<(PackageId, InterfaceId)>,
) -> Result<Option<[u8; 64]>, Error> {
if let Some((package, interface)) = interface {
if let Some(r) = sqlx::query!(
"SELECT key FROM tor WHERE package = $1 AND interface = $2",
package,
interface
)
.fetch_optional(secrets)
.await?
{
Ok(Some(<[u8; 64]>::try_from(r.key).map_err(|e| {
Error::new(
eyre!("expected vec of len 64, got len {}", e.len()),
ErrorKind::ParseDbField,
)
})?))
} else {
Ok(None)
}
} else if let Some(key) = sqlx::query!("SELECT tor_key FROM account WHERE id = 0")
.fetch_one(secrets)
.await?
.tor_key
{
Ok(Some(<[u8; 64]>::try_from(key).map_err(|e| {
Error::new(
eyre!("expected vec of len 64, got len {}", e.len()),
ErrorKind::ParseDbField,
)
})?))
} else {
Ok(None)
}
}
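/// Cryptographic identity for the OS (`interface == None`) or a single package
/// interface: a 32-byte base secret plus a 64-byte expanded tor key, which may
/// differ from the one derived from the base for identities imported from
/// older installs (see `compat`).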
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Key {
interface: Option<(PackageId, InterfaceId)>,
base: [u8; 32],
tor_key: [u8; 64], // Does NOT necessarily match base
}
impl Key {
pub fn interface(&self) -> Option<(PackageId, InterfaceId)> {
self.interface.clone()
}
pub fn as_bytes(&self) -> [u8; 32] {
self.base
}
pub fn internal_address(&self) -> String {
self.interface
.as_ref()
.map(|(pkg_id, _)| format!("{}.embassy", pkg_id))
.unwrap_or_else(|| "embassy".to_owned())
}
pub fn tor_key(&self) -> TorSecretKeyV3 {
self.tor_key.into()
}
pub fn tor_address(&self) -> OnionAddressV3 {
self.tor_key().public().get_onion_address()
}
pub fn base_address(&self) -> String {
self.tor_key()
.public()
.get_onion_address()
.get_address_without_dot_onion()
}
pub fn local_address(&self) -> String {
self.base_address() + ".local"
}
pub fn openssl_key_ed25519(&self) -> PKey<Private> {
PKey::private_key_from_raw_bytes(&self.base, openssl::pkey::Id::ED25519).unwrap()
}
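/// Deterministically derives a NIST P-256 key from the base secret, re-hashing
/// with SHA-256 until the bytes form a valid scalar.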
pub fn openssl_key_nistp256(&self) -> PKey<Private> {
let mut buf = self.base;
loop {
if let Ok(k) = p256::SecretKey::from_slice(&buf) {
return PKey::private_key_from_pkcs8(&*k.to_pkcs8_der().unwrap().as_bytes())
.unwrap();
}
let mut sha = Sha256::new();
sha.update(&buf);
buf = sha.finish();
}
}
pub fn ssh_key(&self) -> Ed25519PrivateKey {
Ed25519PrivateKey::from_bytes(&self.base)
}
pub(crate) fn from_pair(
interface: Option<(PackageId, InterfaceId)>,
bytes: [u8; 32],
tor_key: [u8; 64],
) -> Self {
Self {
interface,
tor_key,
base: bytes,
}
}
pub fn from_bytes(interface: Option<(PackageId, InterfaceId)>, bytes: [u8; 32]) -> Self {
Self::from_pair(interface, bytes, ed25519_expand_key(&bytes))
}
pub fn new(interface: Option<(PackageId, InterfaceId)>) -> Self {
Self::from_bytes(interface, rand::random())
}
pub(super) fn with_certs(self, certs: CertPair, int: X509, root: X509) -> KeyInfo {
KeyInfo {
key: self,
certs,
int,
root,
}
}
pub async fn for_package(
secrets: impl PgExecutor<'_>,
package: &PackageId,
) -> Result<Vec<Self>, Error> {
sqlx::query!(
r#"
SELECT
network_keys.package,
network_keys.interface,
network_keys.key,
tor.key AS "tor_key?"
FROM
network_keys
LEFT JOIN
tor
ON
network_keys.package = tor.package
AND
network_keys.interface = tor.interface
WHERE
network_keys.package = $1
"#,
package
)
.fetch_all(secrets)
.await?
.into_iter()
.map(|row| {
let interface = Some((
package.clone(),
InterfaceId::from(Id::try_from(row.interface)?),
));
let bytes = row.key.try_into().map_err(|e: Vec<u8>| {
Error::new(
eyre!("Invalid length {} for network key, expected 32", e.len()),
crate::ErrorKind::Database,
)
})?;
Ok(match row.tor_key {
Some(tor_key) => Key::from_pair(
interface,
bytes,
tor_key.try_into().map_err(|e: Vec<u8>| {
Error::new(
eyre!("Invalid length {} for tor key, expected 64", e.len()),
crate::ErrorKind::Database,
)
})?,
),
None => Key::from_bytes(interface, bytes),
})
})
.collect()
}
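/// Fetches the persistent key for `interface`, inserting a freshly generated
/// one if none is stored yet; with no interface, the account-wide network key
/// is used instead. A legacy tor key, if one exists, overrides the tor key
/// derived from the base secret.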
pub async fn for_interface<Ex>(
secrets: &mut Ex,
interface: Option<(PackageId, InterfaceId)>,
) -> Result<Self, Error>
where
for<'a> &'a mut Ex: PgExecutor<'a>,
{
let tentative = rand::random::<[u8; 32]>();
let actual = if let Some((pkg, iface)) = &interface {
let k = tentative.as_slice();
let actual = sqlx::query!(
"INSERT INTO network_keys (package, interface, key) VALUES ($1, $2, $3) ON CONFLICT (package, interface) DO UPDATE SET package = EXCLUDED.package RETURNING key",
pkg,
iface,
k,
)
.fetch_one(&mut *secrets)
.await?.key;
let mut bytes = tentative;
bytes.clone_from_slice(actual.get(0..32).ok_or_else(|| {
Error::new(
eyre!("Invalid key size returned from DB"),
crate::ErrorKind::Database,
)
})?);
bytes
} else {
let actual = sqlx::query!("SELECT network_key FROM account WHERE id = 0")
.fetch_one(&mut *secrets)
.await?
.network_key;
let mut bytes = tentative;
bytes.clone_from_slice(actual.get(0..32).ok_or_else(|| {
Error::new(
eyre!("Invalid key size returned from DB"),
crate::ErrorKind::Database,
)
})?);
bytes
};
let mut res = Self::from_bytes(interface, actual);
if let Some(tor_key) = compat(secrets, &res.interface).await? {
res.tor_key = tor_key;
}
Ok(res)
}
}
impl Drop for Key {
fn drop(&mut self) {
self.base.zeroize();
self.tor_key.zeroize();
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct KeyInfo {
key: Key,
certs: CertPair,
int: X509,
root: X509,
}
impl KeyInfo {
pub fn key(&self) -> &Key {
&self.key
}
pub fn certs(&self) -> &CertPair {
&self.certs
}
pub fn int_ca(&self) -> &X509 {
&self.int
}
pub fn root_ca(&self) -> &X509 {
&self.root
}
pub fn fullchain_ed25519(&self) -> Vec<&X509> {
vec![&self.certs.ed25519, &self.int, &self.root]
}
pub fn fullchain_nistp256(&self) -> Vec<&X509> {
vec![&self.certs.nistp256, &self.int, &self.root]
}
}
#[test]
pub fn test_keygen() {
let key = Key::new(None);
key.tor_key();
key.openssl_key_nistp256();
}

View File

@@ -0,0 +1,100 @@
use std::collections::BTreeMap;
use std::net::Ipv4Addr;
use std::sync::{Arc, Weak};
use color_eyre::eyre::eyre;
use tokio::process::{Child, Command};
use tokio::sync::Mutex;
use tracing::instrument;
use crate::util::Invoke;
use crate::{Error, ResultExt};
pub async fn resolve_mdns(hostname: &str) -> Result<Ipv4Addr, Error> {
Ok(String::from_utf8(
Command::new("avahi-resolve-host-name")
.kill_on_drop(true)
.arg("-4")
.arg(hostname)
.invoke(crate::ErrorKind::Network)
.await?,
)?
.split_once("\t")
.ok_or_else(|| {
Error::new(
eyre!("Failed to resolve hostname: {}", hostname),
crate::ErrorKind::Network,
)
})?
.1
.trim()
.parse()?)
}
pub struct MdnsController(Mutex<MdnsControllerInner>);
impl MdnsController {
pub async fn init() -> Result<Self, Error> {
Ok(MdnsController(Mutex::new(
MdnsControllerInner::init().await?,
)))
}
pub async fn add(&self, alias: String) -> Result<Arc<()>, Error> {
self.0.lock().await.add(alias).await
}
pub async fn gc(&self, alias: String) -> Result<(), Error> {
self.0.lock().await.gc(alias).await
}
}
pub struct MdnsControllerInner {
alias_cmd: Option<Child>,
services: BTreeMap<String, Weak<()>>,
}
impl MdnsControllerInner {
#[instrument(skip_all)]
async fn init() -> Result<Self, Error> {
let mut res = MdnsControllerInner {
alias_cmd: None,
services: BTreeMap::new(),
};
res.sync().await?;
Ok(res)
}
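/// Restarts the long-running `avahi-alias` child so it publishes exactly the
/// aliases whose registration tokens are still alive.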
#[instrument(skip_all)]
async fn sync(&mut self) -> Result<(), Error> {
if let Some(mut cmd) = self.alias_cmd.take() {
cmd.kill().await.with_kind(crate::ErrorKind::Network)?;
}
self.alias_cmd = Some(
Command::new("avahi-alias")
.kill_on_drop(true)
.args(
self.services
.iter()
.filter(|(_, rc)| rc.strong_count() > 0)
.map(|(s, _)| s),
)
.spawn()?,
);
Ok(())
}
async fn add(&mut self, alias: String) -> Result<Arc<()>, Error> {
let rc = if let Some(rc) = Weak::upgrade(&self.services.remove(&alias).unwrap_or_default())
{
rc
} else {
Arc::new(())
};
self.services.insert(alias, Arc::downgrade(&rc));
self.sync().await?;
Ok(rc)
}
async fn gc(&mut self, alias: String) -> Result<(), Error> {
if let Some(rc) = Weak::upgrade(&self.services.remove(&alias).unwrap_or_default()) {
self.services.insert(alias, Arc::downgrade(&rc));
}
self.sync().await?;
Ok(())
}
}

View File

@@ -0,0 +1,32 @@
use std::sync::Arc;
use futures::future::BoxFuture;
use hyper::{Body, Error as HyperError, Request, Response};
use rpc_toolkit::command;
use crate::Error;
pub mod dhcp;
pub mod dns;
pub mod interface;
pub mod keys;
pub mod mdns;
pub mod net_controller;
pub mod ssl;
pub mod static_server;
pub mod tor;
pub mod utils;
pub mod vhost;
pub mod web_server;
pub mod wifi;
pub const PACKAGE_CERT_PATH: &str = "/var/lib/embassy/ssl";
#[command(subcommands(tor::tor, dhcp::dhcp, ssl::ssl))]
pub fn net() -> Result<(), Error> {
Ok(())
}
pub type HttpHandler = Arc<
dyn Fn(Request<Body>) -> BoxFuture<'static, Result<Response<Body>, HyperError>> + Send + Sync,
>;

View File

@@ -0,0 +1,369 @@
use std::collections::BTreeMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::{Arc, Weak};
use color_eyre::eyre::eyre;
use models::InterfaceId;
use sqlx::PgExecutor;
use tracing::instrument;
use crate::error::ErrorCollection;
use crate::hostname::Hostname;
use crate::net::dns::DnsController;
use crate::net::keys::Key;
use crate::net::mdns::MdnsController;
use crate::net::ssl::{export_cert, export_key, SslManager};
use crate::net::tor::TorController;
use crate::net::vhost::{AlpnInfo, VHostController};
use crate::s9pk::manifest::PackageId;
use crate::volume::cert_dir;
use crate::{Error, HOST_IP};
pub struct NetController {
pub(super) tor: TorController,
pub(super) mdns: MdnsController,
pub(super) vhost: VHostController,
pub(super) dns: DnsController,
pub(super) ssl: Arc<SslManager>,
pub(super) os_bindings: Vec<Arc<()>>,
}
impl NetController {
#[instrument(skip_all)]
pub async fn init(
tor_control: SocketAddr,
tor_socks: SocketAddr,
dns_bind: &[SocketAddr],
ssl: SslManager,
hostname: &Hostname,
os_key: &Key,
) -> Result<Self, Error> {
let ssl = Arc::new(ssl);
let mut res = Self {
tor: TorController::new(tor_control, tor_socks),
mdns: MdnsController::init().await?,
vhost: VHostController::new(ssl.clone()),
dns: DnsController::init(dns_bind).await?,
ssl,
os_bindings: Vec::new(),
};
res.add_os_bindings(hostname, os_key).await?;
Ok(res)
}
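/// Binds the OS web UI on every built-in hostname (`embassy`, the LAN IP,
/// `localhost`, the device hostname, its `.local` mDNS name, and the tor
/// address), terminating TLS on 443 and proxying to port 80, plus plain HTTP
/// and HTTPS tor hidden-service bindings.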
async fn add_os_bindings(&mut self, hostname: &Hostname, key: &Key) -> Result<(), Error> {
let alpn = Err(AlpnInfo::Specified(vec!["http/1.1".into(), "h2".into()]));
// Internal DNS
self.vhost
.add(
key.clone(),
Some("embassy".into()),
443,
([127, 0, 0, 1], 80).into(),
alpn.clone(),
)
.await?;
self.os_bindings
.push(self.dns.add(None, HOST_IP.into()).await?);
// LAN IP
self.os_bindings.push(
self.vhost
.add(
key.clone(),
None,
443,
([127, 0, 0, 1], 80).into(),
alpn.clone(),
)
.await?,
);
// localhost
self.os_bindings.push(
self.vhost
.add(
key.clone(),
Some("localhost".into()),
443,
([127, 0, 0, 1], 80).into(),
alpn.clone(),
)
.await?,
);
self.os_bindings.push(
self.vhost
.add(
key.clone(),
Some(hostname.no_dot_host_name()),
443,
([127, 0, 0, 1], 80).into(),
alpn.clone(),
)
.await?,
);
// LAN mDNS
self.os_bindings.push(
self.vhost
.add(
key.clone(),
Some(hostname.local_domain_name()),
443,
([127, 0, 0, 1], 80).into(),
alpn.clone(),
)
.await?,
);
// Tor (http)
self.os_bindings.push(
self.tor
.add(key.tor_key(), 80, ([127, 0, 0, 1], 80).into())
.await?,
);
// Tor (https)
self.os_bindings.push(
self.vhost
.add(
key.clone(),
Some(key.tor_address().to_string()),
443,
([127, 0, 0, 1], 80).into(),
alpn.clone(),
)
.await?,
);
self.os_bindings.push(
self.tor
.add(key.tor_key(), 443, ([127, 0, 0, 1], 443).into())
.await?,
);
Ok(())
}
#[instrument(skip_all)]
pub async fn create_service(
self: &Arc<Self>,
package: PackageId,
ip: Ipv4Addr,
) -> Result<NetService, Error> {
let dns = self.dns.add(Some(package.clone()), ip).await?;
Ok(NetService {
shutdown: false,
id: package,
ip,
dns,
controller: Arc::downgrade(self),
tor: BTreeMap::new(),
lan: BTreeMap::new(),
})
}
async fn add_tor(
&self,
key: &Key,
external: u16,
target: SocketAddr,
) -> Result<Vec<Arc<()>>, Error> {
let mut rcs = Vec::with_capacity(1);
rcs.push(self.tor.add(key.tor_key(), external, target).await?);
Ok(rcs)
}
async fn remove_tor(&self, key: &Key, external: u16, rcs: Vec<Arc<()>>) -> Result<(), Error> {
drop(rcs);
self.tor.gc(Some(key.tor_key()), Some(external)).await
}
async fn add_lan(
&self,
key: Key,
external: u16,
target: SocketAddr,
connect_ssl: Result<(), AlpnInfo>,
) -> Result<Vec<Arc<()>>, Error> {
let mut rcs = Vec::with_capacity(2);
rcs.push(
self.vhost
.add(
key.clone(),
Some(key.local_address()),
external,
target.into(),
connect_ssl,
)
.await?,
);
rcs.push(self.mdns.add(key.base_address()).await?);
Ok(rcs)
}
async fn remove_lan(&self, key: &Key, external: u16, rcs: Vec<Arc<()>>) -> Result<(), Error> {
drop(rcs);
self.mdns.gc(key.base_address()).await?;
self.vhost.gc(Some(key.local_address()), external).await
}
}
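/// Per-package handle over the network bindings (DNS record, tor hidden
/// services, LAN vhosts) created for that package; dropping it without calling
/// `remove_all` tears the bindings down on a background task.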
pub struct NetService {
shutdown: bool,
id: PackageId,
ip: Ipv4Addr,
dns: Arc<()>,
controller: Weak<NetController>,
tor: BTreeMap<(InterfaceId, u16), (Key, Vec<Arc<()>>)>,
lan: BTreeMap<(InterfaceId, u16), (Key, Vec<Arc<()>>)>,
}
impl NetService {
fn net_controller(&self) -> Result<Arc<NetController>, Error> {
Weak::upgrade(&self.controller).ok_or_else(|| {
Error::new(
eyre!("NetController is shutdown"),
crate::ErrorKind::Network,
)
})
}
pub async fn add_tor<Ex>(
&mut self,
secrets: &mut Ex,
id: InterfaceId,
external: u16,
internal: u16,
) -> Result<(), Error>
where
for<'a> &'a mut Ex: PgExecutor<'a>,
{
let key = Key::for_interface(secrets, Some((self.id.clone(), id.clone()))).await?;
let ctrl = self.net_controller()?;
let tor_idx = (id, external);
let mut tor = self
.tor
.remove(&tor_idx)
.unwrap_or_else(|| (key.clone(), Vec::new()));
tor.1.append(
&mut ctrl
.add_tor(&key, external, SocketAddr::new(self.ip.into(), internal))
.await?,
);
self.tor.insert(tor_idx, tor);
Ok(())
}
pub async fn remove_tor(&mut self, id: InterfaceId, external: u16) -> Result<(), Error> {
let ctrl = self.net_controller()?;
if let Some((key, rcs)) = self.tor.remove(&(id, external)) {
ctrl.remove_tor(&key, external, rcs).await?;
}
Ok(())
}
pub async fn add_lan<Ex>(
&mut self,
secrets: &mut Ex,
id: InterfaceId,
external: u16,
internal: u16,
connect_ssl: Result<(), AlpnInfo>,
) -> Result<(), Error>
where
for<'a> &'a mut Ex: PgExecutor<'a>,
{
let key = Key::for_interface(secrets, Some((self.id.clone(), id.clone()))).await?;
let ctrl = self.net_controller()?;
let lan_idx = (id, external);
let mut lan = self
.lan
.remove(&lan_idx)
.unwrap_or_else(|| (key.clone(), Vec::new()));
lan.1.append(
&mut ctrl
.add_lan(
key,
external,
SocketAddr::new(self.ip.into(), internal),
connect_ssl,
)
.await?,
);
self.lan.insert(lan_idx, lan);
Ok(())
}
pub async fn remove_lan(&mut self, id: InterfaceId, external: u16) -> Result<(), Error> {
let ctrl = self.net_controller()?;
if let Some((key, rcs)) = self.lan.remove(&(id, external)) {
ctrl.remove_lan(&key, external, rcs).await?;
}
Ok(())
}
pub async fn export_cert<Ex>(
&self,
secrets: &mut Ex,
id: &InterfaceId,
ip: IpAddr,
) -> Result<(), Error>
where
for<'a> &'a mut Ex: PgExecutor<'a>,
{
let key = Key::for_interface(secrets, Some((self.id.clone(), id.clone()))).await?;
let ctrl = self.net_controller()?;
let cert = ctrl.ssl.with_certs(key, ip).await?;
let cert_dir = cert_dir(&self.id, id);
tokio::fs::create_dir_all(&cert_dir).await?;
export_key(
&cert.key().openssl_key_nistp256(),
&cert_dir.join(format!("{id}.key.pem")),
)
.await?;
export_cert(
&cert.fullchain_nistp256(),
&cert_dir.join(format!("{id}.cert.pem")),
)
.await?; // TODO: can upgrade to ed25519?
Ok(())
}
pub async fn remove_all(mut self) -> Result<(), Error> {
self.shutdown = true;
let mut errors = ErrorCollection::new();
if let Some(ctrl) = Weak::upgrade(&self.controller) {
for ((_, external), (key, rcs)) in std::mem::take(&mut self.lan) {
errors.handle(ctrl.remove_lan(&key, external, rcs).await);
}
for ((_, external), (key, rcs)) in std::mem::take(&mut self.tor) {
errors.handle(ctrl.remove_tor(&key, external, rcs).await);
}
std::mem::take(&mut self.dns);
errors.handle(ctrl.dns.gc(Some(self.id.clone()), self.ip).await);
errors.into_result()
} else {
tracing::warn!("NetService dropped after NetController is shutdown");
Err(Error::new(
eyre!("NetController is shutdown"),
crate::ErrorKind::Network,
))
}
}
}
impl Drop for NetService {
fn drop(&mut self) {
if !self.shutdown {
tracing::debug!("Dropping NetService for {}", self.id);
let svc = std::mem::replace(
self,
NetService {
shutdown: true,
id: Default::default(),
ip: Ipv4Addr::UNSPECIFIED,
dns: Default::default(),
controller: Default::default(),
tor: Default::default(),
lan: Default::default(),
},
);
tokio::spawn(async move { svc.remove_all().await.unwrap() });
}
}
}

458
core/startos/src/net/ssl.rs Normal file
View File

@@ -0,0 +1,458 @@
use std::cmp::Ordering;
use std::collections::{BTreeMap, BTreeSet};
use std::net::IpAddr;
use std::path::Path;
use std::time::{SystemTime, UNIX_EPOCH};
use futures::FutureExt;
use libc::time_t;
use openssl::asn1::{Asn1Integer, Asn1Time};
use openssl::bn::{BigNum, MsbOption};
use openssl::ec::{EcGroup, EcKey};
use openssl::hash::MessageDigest;
use openssl::nid::Nid;
use openssl::pkey::{PKey, Private};
use openssl::x509::{X509Builder, X509Extension, X509NameBuilder, X509};
use openssl::*;
use rpc_toolkit::command;
use tokio::sync::{Mutex, RwLock};
use tracing::instrument;
use crate::account::AccountInfo;
use crate::context::RpcContext;
use crate::hostname::Hostname;
use crate::init::check_time_is_synchronized;
use crate::net::dhcp::ips;
use crate::net::keys::{Key, KeyInfo};
use crate::{Error, ErrorKind, ResultExt, SOURCE_DATE};
static CERTIFICATE_VERSION: i32 = 2; // X.509 "version 3" is encoded as the integer 2 in the certificate, since versions are zero-indexed.
fn unix_time(time: SystemTime) -> time_t {
time.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs() as time_t)
.or_else(|_| UNIX_EPOCH.elapsed().map(|d| -(d.as_secs() as time_t)))
.unwrap_or_default()
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct CertPair {
pub ed25519: X509,
pub nistp256: X509,
}
impl CertPair {
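/// Re-issues the ed25519 / nistp256 leaf certificates unless the existing pair
/// is already valid, has at least 30 days left, and covers every requested IP
/// SAN; returns the pair along with whether anything was re-signed.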
fn updated(
pair: Option<&Self>,
hostname: &Hostname,
signer: (&PKey<Private>, &X509),
applicant: &Key,
ip: BTreeSet<IpAddr>,
) -> Result<(Self, bool), Error> {
let mut updated = false;
let mut updated_cert = |cert: Option<&X509>, osk: PKey<Private>| -> Result<X509, Error> {
let mut ips = BTreeSet::new();
if let Some(cert) = cert {
ips.extend(
cert.subject_alt_names()
.iter()
.flatten()
.filter_map(|a| a.ipaddress())
.filter_map(|a| match a.len() {
4 => Some::<IpAddr>(<[u8; 4]>::try_from(a).unwrap().into()),
16 => Some::<IpAddr>(<[u8; 16]>::try_from(a).unwrap().into()),
_ => None,
}),
);
if cert
.not_before()
.compare(Asn1Time::days_from_now(0)?.as_ref())?
== Ordering::Less
&& cert
.not_after()
.compare(Asn1Time::days_from_now(30)?.as_ref())?
== Ordering::Greater
&& ips.is_superset(&ip)
{
return Ok(cert.clone());
}
}
ips.extend(ip.iter().copied());
updated = true;
make_leaf_cert(signer, (&osk, &SANInfo::new(&applicant, hostname, ips)))
};
Ok((
Self {
ed25519: updated_cert(pair.map(|c| &c.ed25519), applicant.openssl_key_ed25519())?,
nistp256: updated_cert(
pair.map(|c| &c.nistp256),
applicant.openssl_key_nistp256(),
)?,
},
updated,
))
}
}
pub async fn root_ca_start_time() -> Result<SystemTime, Error> {
Ok(if check_time_is_synchronized().await? {
SystemTime::now()
} else {
*SOURCE_DATE
})
}
#[derive(Debug)]
pub struct SslManager {
hostname: Hostname,
root_cert: X509,
int_key: PKey<Private>,
int_cert: X509,
cert_cache: RwLock<BTreeMap<Key, CertPair>>,
}
impl SslManager {
pub fn new(account: &AccountInfo, start_time: SystemTime) -> Result<Self, Error> {
let int_key = generate_key()?;
let int_cert = make_int_cert(
(&account.root_ca_key, &account.root_ca_cert),
&int_key,
start_time,
)?;
Ok(Self {
hostname: account.hostname.clone(),
root_cert: account.root_ca_cert.clone(),
int_key,
int_cert,
cert_cache: RwLock::new(BTreeMap::new()),
})
}
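/// Returns leaf certificates for `key` whose SANs cover `ip` plus the
/// addresses reported by `dhcp::ips()`, issuing and caching a new pair only
/// when the cached one is stale or missing an address.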
pub async fn with_certs(&self, key: Key, ip: IpAddr) -> Result<KeyInfo, Error> {
let mut ips = ips().await?;
ips.insert(ip);
let (pair, updated) = CertPair::updated(
self.cert_cache.read().await.get(&key),
&self.hostname,
(&self.int_key, &self.int_cert),
&key,
ips,
)?;
if updated {
self.cert_cache
.write()
.await
.insert(key.clone(), pair.clone());
}
Ok(key.with_certs(pair, self.int_cert.clone(), self.root_cert.clone()))
}
}
const EC_CURVE_NAME: nid::Nid = nid::Nid::X9_62_PRIME256V1;
lazy_static::lazy_static! {
static ref EC_GROUP: EcGroup = EcGroup::from_curve_name(EC_CURVE_NAME).unwrap();
static ref SSL_MUTEX: Mutex<()> = Mutex::new(()); // TODO: make thread safe
}
pub async fn export_key(key: &PKey<Private>, target: &Path) -> Result<(), Error> {
tokio::fs::write(target, key.private_key_to_pem_pkcs8()?)
.map(|res| res.with_ctx(|_| (ErrorKind::Filesystem, target.display().to_string())))
.await?;
Ok(())
}
pub async fn export_cert(chain: &[&X509], target: &Path) -> Result<(), Error> {
tokio::fs::write(
target,
chain
.into_iter()
.flat_map(|c| c.to_pem().unwrap())
.collect::<Vec<u8>>(),
)
.await?;
Ok(())
}
#[instrument(skip_all)]
fn rand_serial() -> Result<Asn1Integer, Error> {
let mut bn = BigNum::new()?;
bn.rand(64, MsbOption::MAYBE_ZERO, false)?;
let asn1 = Asn1Integer::from_bn(&bn)?;
Ok(asn1)
}
#[instrument(skip_all)]
pub fn generate_key() -> Result<PKey<Private>, Error> {
let new_key = EcKey::generate(EC_GROUP.as_ref())?;
let key = PKey::from_ec_key(new_key)?;
Ok(key)
}
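/// Builds the self-signed root CA: not-before backdated one day, roughly ten
/// years of validity, and `CA:true` with certificate- and CRL-signing usage.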
#[instrument(skip_all)]
pub fn make_root_cert(
root_key: &PKey<Private>,
hostname: &Hostname,
start_time: SystemTime,
) -> Result<X509, Error> {
let mut builder = X509Builder::new()?;
builder.set_version(CERTIFICATE_VERSION)?;
let unix_start_time = unix_time(start_time);
let embargo = Asn1Time::from_unix(unix_start_time - 86400)?;
builder.set_not_before(&embargo)?;
let expiration = Asn1Time::from_unix(unix_start_time + (10 * 364 * 86400))?;
builder.set_not_after(&expiration)?;
builder.set_serial_number(&*rand_serial()?)?;
let mut subject_name_builder = X509NameBuilder::new()?;
subject_name_builder.append_entry_by_text("CN", &format!("{} Local Root CA", &*hostname.0))?;
subject_name_builder.append_entry_by_text("O", "Start9")?;
subject_name_builder.append_entry_by_text("OU", "StartOS")?;
let subject_name = subject_name_builder.build();
builder.set_subject_name(&subject_name)?;
builder.set_issuer_name(&subject_name)?;
builder.set_pubkey(&root_key)?;
// Extensions
let cfg = conf::Conf::new(conf::ConfMethod::default())?;
let ctx = builder.x509v3_context(None, Some(&cfg));
// subjectKeyIdentifier = hash
let subject_key_identifier =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_KEY_IDENTIFIER, "hash")?;
// basicConstraints = critical, CA:true, pathlen:0
let basic_constraints = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::BASIC_CONSTRAINTS,
"critical,CA:true",
)?;
// keyUsage = critical, digitalSignature, cRLSign, keyCertSign
let key_usage = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::KEY_USAGE,
"critical,digitalSignature,cRLSign,keyCertSign",
)?;
builder.append_extension(subject_key_identifier)?;
builder.append_extension(basic_constraints)?;
builder.append_extension(key_usage)?;
builder.sign(&root_key, MessageDigest::sha256())?;
let cert = builder.build();
Ok(cert)
}
#[instrument(skip_all)]
pub fn make_int_cert(
signer: (&PKey<Private>, &X509),
applicant: &PKey<Private>,
start_time: SystemTime,
) -> Result<X509, Error> {
let mut builder = X509Builder::new()?;
builder.set_version(CERTIFICATE_VERSION)?;
let unix_start_time = unix_time(start_time);
let embargo = Asn1Time::from_unix(unix_start_time - 86400)?;
builder.set_not_before(&embargo)?;
let expiration = Asn1Time::from_unix(unix_start_time + (10 * 364 * 86400))?;
builder.set_not_after(&expiration)?;
builder.set_serial_number(&*rand_serial()?)?;
let mut subject_name_builder = X509NameBuilder::new()?;
subject_name_builder.append_entry_by_text("CN", "StartOS Local Intermediate CA")?;
subject_name_builder.append_entry_by_text("O", "Start9")?;
subject_name_builder.append_entry_by_text("OU", "StartOS")?;
let subject_name = subject_name_builder.build();
builder.set_subject_name(&subject_name)?;
builder.set_issuer_name(signer.1.subject_name())?;
builder.set_pubkey(&applicant)?;
let cfg = conf::Conf::new(conf::ConfMethod::default())?;
let ctx = builder.x509v3_context(Some(&signer.1), Some(&cfg));
// subjectKeyIdentifier = hash
let subject_key_identifier =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_KEY_IDENTIFIER, "hash")?;
// authorityKeyIdentifier = keyid:always,issuer
let authority_key_identifier = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::AUTHORITY_KEY_IDENTIFIER,
"keyid:always,issuer",
)?;
// basicConstraints = critical, CA:true, pathlen:0
let basic_constraints = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::BASIC_CONSTRAINTS,
"critical,CA:true,pathlen:0",
)?;
// keyUsage = critical, digitalSignature, cRLSign, keyCertSign
let key_usage = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::KEY_USAGE,
"critical,digitalSignature,cRLSign,keyCertSign",
)?;
builder.append_extension(subject_key_identifier)?;
builder.append_extension(authority_key_identifier)?;
builder.append_extension(basic_constraints)?;
builder.append_extension(key_usage)?;
builder.sign(&signer.0, MessageDigest::sha256())?;
let cert = builder.build();
Ok(cert)
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum MaybeWildcard {
WithWildcard(String),
WithoutWildcard(String),
}
impl MaybeWildcard {
pub fn as_str(&self) -> &str {
match self {
MaybeWildcard::WithWildcard(s) => s.as_str(),
MaybeWildcard::WithoutWildcard(s) => s.as_str(),
}
}
}
impl std::fmt::Display for MaybeWildcard {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MaybeWildcard::WithWildcard(dns) => write!(f, "DNS:{dns},DNS:*.{dns}"),
MaybeWildcard::WithoutWildcard(dns) => write!(f, "DNS:{dns}"),
}
}
}
#[derive(Debug)]
pub struct SANInfo {
pub dns: BTreeSet<MaybeWildcard>,
pub ips: BTreeSet<IpAddr>,
}
impl SANInfo {
pub fn new(key: &Key, hostname: &Hostname, ips: BTreeSet<IpAddr>) -> Self {
let mut dns = BTreeSet::new();
if let Some((id, _)) = key.interface() {
dns.insert(MaybeWildcard::WithWildcard(format!("{id}.embassy")));
dns.insert(MaybeWildcard::WithWildcard(key.local_address().to_string()));
} else {
dns.insert(MaybeWildcard::WithoutWildcard("embassy".to_owned()));
dns.insert(MaybeWildcard::WithWildcard(hostname.local_domain_name()));
dns.insert(MaybeWildcard::WithoutWildcard(hostname.no_dot_host_name()));
dns.insert(MaybeWildcard::WithoutWildcard("localhost".to_owned()));
}
dns.insert(MaybeWildcard::WithWildcard(key.tor_address().to_string()));
Self { dns, ips }
}
}
impl std::fmt::Display for SANInfo {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut written = false;
for dns in &self.dns {
if written {
write!(f, ",")?;
}
written = true;
write!(f, "{dns}")?;
}
for ip in &self.ips {
if written {
write!(f, ",")?;
}
written = true;
write!(f, "IP:{ip}")?;
}
Ok(())
}
}
#[instrument(skip_all)]
pub fn make_leaf_cert(
signer: (&PKey<Private>, &X509),
applicant: (&PKey<Private>, &SANInfo),
) -> Result<X509, Error> {
let mut builder = X509Builder::new()?;
builder.set_version(CERTIFICATE_VERSION)?;
let embargo = Asn1Time::from_unix(unix_time(SystemTime::now()) - 86400)?;
builder.set_not_before(&embargo)?;
// Google, Apple, and Mozilla reject certificates with validity periods longer than 398 days
// https://techbeacon.com/security/google-apple-mozilla-enforce-1-year-max-security-certifications
let expiration = Asn1Time::days_from_now(397)?;
builder.set_not_after(&expiration)?;
builder.set_serial_number(&*rand_serial()?)?;
let mut subject_name_builder = X509NameBuilder::new()?;
subject_name_builder.append_entry_by_text(
"CN",
applicant
.1
.dns
.first()
.map(MaybeWildcard::as_str)
.unwrap_or("localhost"),
)?;
subject_name_builder.append_entry_by_text("O", "Start9")?;
subject_name_builder.append_entry_by_text("OU", "StartOS")?;
let subject_name = subject_name_builder.build();
builder.set_subject_name(&subject_name)?;
builder.set_issuer_name(signer.1.subject_name())?;
builder.set_pubkey(&applicant.0)?;
// Extensions
let cfg = conf::Conf::new(conf::ConfMethod::default())?;
let ctx = builder.x509v3_context(Some(&signer.1), Some(&cfg));
// subjectKeyIdentifier = hash
let subject_key_identifier =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_KEY_IDENTIFIER, "hash")?;
// authorityKeyIdentifier = keyid:always,issuer
let authority_key_identifier = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::AUTHORITY_KEY_IDENTIFIER,
"keyid,issuer:always",
)?;
let basic_constraints =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::BASIC_CONSTRAINTS, "CA:FALSE")?;
let key_usage = X509Extension::new_nid(
Some(&cfg),
Some(&ctx),
Nid::KEY_USAGE,
"critical,digitalSignature,keyEncipherment",
)?;
let san_string = applicant.1.to_string();
let subject_alt_name =
X509Extension::new_nid(Some(&cfg), Some(&ctx), Nid::SUBJECT_ALT_NAME, &san_string)?;
builder.append_extension(subject_key_identifier)?;
builder.append_extension(authority_key_identifier)?;
builder.append_extension(subject_alt_name)?;
builder.append_extension(basic_constraints)?;
builder.append_extension(key_usage)?;
builder.sign(&signer.0, MessageDigest::sha256())?;
let cert = builder.build();
Ok(cert)
}
#[command(subcommands(size))]
pub async fn ssl() -> Result<(), Error> {
Ok(())
}
#[command]
pub async fn size(#[context] ctx: RpcContext) -> Result<String, Error> {
Ok(format!(
"Cert cache size: {}",
ctx.net_controller.ssl.cert_cache.read().await.len()
))
}

View File

@@ -0,0 +1,580 @@
use std::fs::Metadata;
use std::future::Future;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::UNIX_EPOCH;
use async_compression::tokio::bufread::GzipEncoder;
use color_eyre::eyre::eyre;
use digest::Digest;
use futures::FutureExt;
use http::header::ACCEPT_ENCODING;
use http::request::Parts as RequestParts;
use hyper::{Body, Method, Request, Response, StatusCode};
use include_dir::{include_dir, Dir};
use new_mime_guess::MimeGuess;
use openssl::hash::MessageDigest;
use openssl::x509::X509;
use rpc_toolkit::rpc_handler;
use tokio::fs::File;
use tokio::io::BufReader;
use tokio_util::io::ReaderStream;
use crate::context::{DiagnosticContext, InstallContext, RpcContext, SetupContext};
use crate::core::rpc_continuations::RequestGuid;
use crate::db::subscribe;
use crate::install::PKG_PUBLIC_DIR;
use crate::middleware::auth::{auth as auth_middleware, HasValidSession};
use crate::middleware::cors::cors;
use crate::middleware::db::db as db_middleware;
use crate::middleware::diagnostic::diagnostic as diagnostic_middleware;
use crate::net::HttpHandler;
use crate::{diagnostic_api, install_api, main_api, setup_api, Error, ErrorKind, ResultExt};
static NOT_FOUND: &[u8] = b"Not Found";
static METHOD_NOT_ALLOWED: &[u8] = b"Method Not Allowed";
static NOT_AUTHORIZED: &[u8] = b"Not Authorized";
static EMBEDDED_UIS: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/../../web/dist/static");
const PROXY_STRIP_HEADERS: &[&str] = &["cookie", "host", "origin", "referer", "user-agent"];
fn status_fn(_: i32) -> StatusCode {
StatusCode::OK
}
#[derive(Clone)]
pub enum UiMode {
Setup,
Diag,
Install,
Main,
}
impl UiMode {
fn path(&self, path: &str) -> PathBuf {
match self {
Self::Setup => Path::new("setup-wizard").join(path),
Self::Diag => Path::new("diagnostic-ui").join(path),
Self::Install => Path::new("install-wizard").join(path),
Self::Main => Path::new("ui").join(path),
}
}
}
pub async fn setup_ui_file_router(ctx: SetupContext) -> Result<HttpHandler, Error> {
let handler: HttpHandler = Arc::new(move |req| {
let ctx = ctx.clone();
let ui_mode = UiMode::Setup;
async move {
let res = match req.uri().path() {
path if path.starts_with("/rpc/") => {
let rpc_handler = rpc_handler!({
command: setup_api,
context: ctx,
status: status_fn,
middleware: [
cors,
]
});
rpc_handler(req)
.await
.map_err(|err| Error::new(eyre!("{}", err), crate::ErrorKind::Network))
}
_ => alt_ui(req, ui_mode).await,
};
match res {
Ok(data) => Ok(data),
Err(err) => Ok(server_error(err)),
}
}
.boxed()
});
Ok(handler)
}
pub async fn diag_ui_file_router(ctx: DiagnosticContext) -> Result<HttpHandler, Error> {
let handler: HttpHandler = Arc::new(move |req| {
let ctx = ctx.clone();
let ui_mode = UiMode::Diag;
async move {
let res = match req.uri().path() {
path if path.starts_with("/rpc/") => {
let rpc_handler = rpc_handler!({
command: diagnostic_api,
context: ctx,
status: status_fn,
middleware: [
cors,
diagnostic_middleware,
]
});
rpc_handler(req)
.await
.map_err(|err| Error::new(eyre!("{}", err), crate::ErrorKind::Network))
}
_ => alt_ui(req, ui_mode).await,
};
match res {
Ok(data) => Ok(data),
Err(err) => Ok(server_error(err)),
}
}
.boxed()
});
Ok(handler)
}
pub async fn install_ui_file_router(ctx: InstallContext) -> Result<HttpHandler, Error> {
let handler: HttpHandler = Arc::new(move |req| {
let ctx = ctx.clone();
let ui_mode = UiMode::Install;
async move {
let res = match req.uri().path() {
path if path.starts_with("/rpc/") => {
let rpc_handler = rpc_handler!({
command: install_api,
context: ctx,
status: status_fn,
middleware: [
cors,
]
});
rpc_handler(req)
.await
.map_err(|err| Error::new(eyre!("{}", err), crate::ErrorKind::Network))
}
_ => alt_ui(req, ui_mode).await,
};
match res {
Ok(data) => Ok(data),
Err(err) => Ok(server_error(err)),
}
}
.boxed()
});
Ok(handler)
}
pub async fn main_ui_server_router(ctx: RpcContext) -> Result<HttpHandler, Error> {
let handler: HttpHandler = Arc::new(move |req| {
let ctx = ctx.clone();
async move {
let res = match req.uri().path() {
path if path.starts_with("/rpc/") => {
let auth_middleware = auth_middleware(ctx.clone());
let db_middleware = db_middleware(ctx.clone());
let rpc_handler = rpc_handler!({
command: main_api,
context: ctx,
status: status_fn,
middleware: [
cors,
auth_middleware,
db_middleware,
]
});
rpc_handler(req)
.await
.map_err(|err| Error::new(eyre!("{}", err), crate::ErrorKind::Network))
}
"/ws/db" => subscribe(ctx, req).await,
path if path.starts_with("/ws/rpc/") => {
match RequestGuid::from(path.strip_prefix("/ws/rpc/").unwrap()) {
None => {
tracing::debug!("No Guid Path");
Ok::<_, Error>(bad_request())
}
Some(guid) => match ctx.get_ws_continuation_handler(&guid).await {
Some(cont) => match cont(req).await {
Ok::<_, Error>(r) => Ok::<_, Error>(r),
Err(err) => Ok::<_, Error>(server_error(err)),
},
_ => Ok::<_, Error>(not_found()),
},
}
}
path if path.starts_with("/rest/rpc/") => {
match RequestGuid::from(path.strip_prefix("/rest/rpc/").unwrap()) {
None => {
tracing::debug!("No Guid Path");
Ok::<_, Error>(bad_request())
}
Some(guid) => match ctx.get_rest_continuation_handler(&guid).await {
None => Ok::<_, Error>(not_found()),
Some(cont) => match cont(req).await {
Ok::<_, Error>(r) => Ok::<_, Error>(r),
Err(e) => Ok::<_, Error>(server_error(e)),
},
},
}
}
_ => main_embassy_ui(req, ctx).await,
};
match res {
Ok(data) => Ok(data),
Err(err) => Ok(server_error(err)),
}
}
.boxed()
});
Ok(handler)
}
async fn alt_ui(req: Request<Body>, ui_mode: UiMode) -> Result<Response<Body>, Error> {
let (request_parts, _body) = req.into_parts();
match &request_parts.method {
&Method::GET => {
let uri_path = ui_mode.path(
request_parts
.uri
.path()
.strip_prefix('/')
.unwrap_or(request_parts.uri.path()),
);
let file = EMBEDDED_UIS
.get_file(&*uri_path)
.or_else(|| EMBEDDED_UIS.get_file(&*ui_mode.path("index.html")));
if let Some(file) = file {
FileData::from_embedded(&request_parts, file)
.into_response(&request_parts)
.await
} else {
Ok(not_found())
}
}
_ => Ok(method_not_allowed()),
}
}
async fn if_authorized<
F: FnOnce() -> Fut,
Fut: Future<Output = Result<Response<Body>, Error>> + Send + Sync,
>(
ctx: &RpcContext,
parts: &RequestParts,
f: F,
) -> Result<Response<Body>, Error> {
if let Err(e) = HasValidSession::from_request_parts(parts, ctx).await {
un_authorized(e, parts.uri.path())
} else {
f().await
}
}
async fn main_embassy_ui(req: Request<Body>, ctx: RpcContext) -> Result<Response<Body>, Error> {
let (request_parts, _body) = req.into_parts();
match (
&request_parts.method,
request_parts
.uri
.path()
.strip_prefix('/')
.unwrap_or(request_parts.uri.path())
.split_once('/'),
) {
(&Method::GET, Some(("public", path))) => {
if_authorized(&ctx, &request_parts, || async {
let sub_path = Path::new(path);
if let Ok(rest) = sub_path.strip_prefix("package-data") {
FileData::from_path(
&request_parts,
&ctx.datadir.join(PKG_PUBLIC_DIR).join(rest),
)
.await?
.into_response(&request_parts)
.await
} else {
Ok(not_found())
}
})
.await
}
(&Method::GET, Some(("proxy", target))) => {
if_authorized(&ctx, &request_parts, || async {
let target = urlencoding::decode(target)?;
let res = ctx
.client
.get(target.as_ref())
.headers(
request_parts
.headers
.iter()
.filter(|(h, _)| {
!PROXY_STRIP_HEADERS
.iter()
.any(|bad| h.as_str().eq_ignore_ascii_case(bad))
})
.map(|(h, v)| (h.clone(), v.clone()))
.collect(),
)
.send()
.await
.with_kind(crate::ErrorKind::Network)?;
let mut hres = Response::builder().status(res.status());
for (h, v) in res.headers().clone() {
if let Some(h) = h {
hres = hres.header(h, v);
}
}
hres.body(Body::wrap_stream(res.bytes_stream()))
.with_kind(crate::ErrorKind::Network)
})
.await
}
(&Method::GET, Some(("eos", "local.crt"))) => {
cert_send(&ctx.account.read().await.root_ca_cert)
}
(&Method::GET, _) => {
let uri_path = UiMode::Main.path(
request_parts
.uri
.path()
.strip_prefix('/')
.unwrap_or(request_parts.uri.path()),
);
let file = EMBEDDED_UIS
.get_file(&*uri_path)
.or_else(|| EMBEDDED_UIS.get_file(&*UiMode::Main.path("index.html")));
if let Some(file) = file {
FileData::from_embedded(&request_parts, file)
.into_response(&request_parts)
.await
} else {
Ok(not_found())
}
}
_ => Ok(method_not_allowed()),
}
}
fn un_authorized(err: Error, path: &str) -> Result<Response<Body>, Error> {
tracing::warn!("unauthorized for {} @{:?}", err, path);
tracing::debug!("{:?}", err);
Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(NOT_AUTHORIZED.into())
.unwrap())
}
/// HTTP status code 404
fn not_found() -> Response<Body> {
Response::builder()
.status(StatusCode::NOT_FOUND)
.body(NOT_FOUND.into())
.unwrap()
}
/// HTTP status code 405
fn method_not_allowed() -> Response<Body> {
Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.body(METHOD_NOT_ALLOWED.into())
.unwrap()
}
fn server_error(err: Error) -> Response<Body> {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(err.to_string().into())
.unwrap()
}
fn bad_request() -> Response<Body> {
Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::empty())
.unwrap()
}
fn cert_send(cert: &X509) -> Result<Response<Body>, Error> {
let pem = cert.to_pem()?;
Response::builder()
.status(StatusCode::OK)
.header(
http::header::ETAG,
base32::encode(
base32::Alphabet::RFC4648 { padding: false },
&*cert.digest(MessageDigest::sha256())?,
)
.to_lowercase(),
)
.header(http::header::CONTENT_TYPE, "application/x-pem-file")
.header(http::header::CONTENT_LENGTH, pem.len())
.body(Body::from(pem))
.with_kind(ErrorKind::Network)
}
struct FileData {
data: Body,
len: Option<u64>,
encoding: Option<&'static str>,
e_tag: String,
mime: Option<String>,
}
impl FileData {
fn from_embedded(req: &RequestParts, file: &'static include_dir::File<'static>) -> Self {
let path = file.path();
let (encoding, data) = req
.headers
.get_all(ACCEPT_ENCODING)
.into_iter()
.filter_map(|h| h.to_str().ok())
.flat_map(|s| s.split(","))
.filter_map(|s| s.split(";").next())
.map(|s| s.trim())
.fold((None, file.contents()), |acc, e| {
if let Some(file) = (e == "br")
.then_some(())
.and_then(|_| EMBEDDED_UIS.get_file(format!("{}.br", path.display())))
{
(Some("br"), file.contents())
} else if let Some(file) = (e == "gzip" && acc.0 != Some("br"))
.then_some(())
.and_then(|_| EMBEDDED_UIS.get_file(format!("{}.gz", path.display())))
{
(Some("gzip"), file.contents())
} else {
acc
}
});
Self {
len: Some(data.len() as u64),
encoding,
data: data.into(),
e_tag: e_tag(path, None),
mime: MimeGuess::from_path(path)
.first()
.map(|m| m.essence_str().to_owned()),
}
}
async fn from_path(req: &RequestParts, path: &Path) -> Result<Self, Error> {
let encoding = req
.headers
.get_all(ACCEPT_ENCODING)
.into_iter()
.filter_map(|h| h.to_str().ok())
.flat_map(|s| s.split(","))
.filter_map(|s| s.split(";").next())
.map(|s| s.trim())
.any(|e| e == "gzip")
.then_some("gzip");
let file = File::open(path)
.await
.with_ctx(|_| (ErrorKind::Filesystem, path.display().to_string()))?;
let metadata = file
.metadata()
.await
.with_ctx(|_| (ErrorKind::Filesystem, path.display().to_string()))?;
let e_tag = e_tag(path, Some(&metadata));
let (len, data) = if encoding == Some("gzip") {
(
None,
Body::wrap_stream(ReaderStream::new(GzipEncoder::new(BufReader::new(file)))),
)
} else {
(
Some(metadata.len()),
Body::wrap_stream(ReaderStream::new(file)),
)
};
Ok(Self {
data,
len,
encoding,
e_tag,
mime: MimeGuess::from_path(path)
.first()
.map(|m| m.essence_str().to_owned()),
})
}
async fn into_response(self, req: &RequestParts) -> Result<Response<Body>, Error> {
let mut builder = Response::builder();
if let Some(mime) = self.mime {
builder = builder.header(http::header::CONTENT_TYPE, &*mime);
}
builder = builder.header(http::header::ETAG, &*self.e_tag);
builder = builder.header(
http::header::CACHE_CONTROL,
"public, max-age=21000000, immutable",
);
if req
.headers
.get_all(http::header::CONNECTION)
.iter()
.flat_map(|s| s.to_str().ok())
.flat_map(|s| s.split(","))
.any(|s| s.trim() == "keep-alive")
{
builder = builder.header(http::header::CONNECTION, "keep-alive");
}
if req
.headers
.get("if-none-match")
.and_then(|h| h.to_str().ok())
== Some(self.e_tag.as_ref())
{
builder = builder.status(StatusCode::NOT_MODIFIED);
builder.body(Body::empty())
} else {
if let Some(len) = self.len {
builder = builder.header(http::header::CONTENT_LENGTH, len);
}
if let Some(encoding) = self.encoding {
builder = builder.header(http::header::CONTENT_ENCODING, encoding);
}
builder.body(self.data)
}
.with_kind(ErrorKind::Network)
}
}
fn e_tag(path: &Path, metadata: Option<&Metadata>) -> String {
let mut hasher = sha2::Sha256::new();
hasher.update(format!("{:?}", path).as_bytes());
if let Some(modified) = metadata.and_then(|m| m.modified().ok()) {
hasher.update(
format!(
"{}",
modified
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
)
.as_bytes(),
);
}
let res = hasher.finalize();
format!(
"\"{}\"",
base32::encode(base32::Alphabet::RFC4648 { padding: false }, res.as_slice()).to_lowercase()
)
}

741
core/startos/src/net/tor.rs Normal file
View File

@@ -0,0 +1,741 @@
use std::collections::BTreeMap;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Weak};
use std::time::Duration;
use clap::ArgMatches;
use color_eyre::eyre::eyre;
use futures::future::BoxFuture;
use futures::{FutureExt, TryStreamExt};
use helpers::NonDetachingJoinHandle;
use itertools::Itertools;
use lazy_static::lazy_static;
use regex::Regex;
use rpc_toolkit::command;
use rpc_toolkit::yajrc::RpcError;
use tokio::net::TcpStream;
use tokio::process::Command;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
use torut::control::{AsyncEvent, AuthenticatedConn, ConnError};
use torut::onion::{OnionAddressV3, TorSecretKeyV3};
use tracing::instrument;
use crate::context::{CliContext, RpcContext};
use crate::logs::{
cli_logs_generic_follow, cli_logs_generic_nofollow, fetch_logs, follow_logs, journalctl,
LogFollowResponse, LogResponse, LogSource,
};
use crate::util::serde::{display_serializable, IoFormat};
use crate::util::{display_none, Invoke};
use crate::{Error, ErrorKind, ResultExt as _};
pub const SYSTEMD_UNIT: &str = "tor@default";
const STARTING_HEALTH_TIMEOUT: u64 = 120; // 2min
enum ErrorLogSeverity {
Fatal { wipe_state: bool },
Unknown { wipe_state: bool },
}
lazy_static! {
static ref LOG_REGEXES: Vec<(Regex, ErrorLogSeverity)> = vec![(
Regex::new("This could indicate a route manipulation attack, network overload, bad local network connectivity, or a bug\\.").unwrap(),
ErrorLogSeverity::Unknown { wipe_state: true }
),(
Regex::new("died due to an invalid selected path").unwrap(),
ErrorLogSeverity::Fatal { wipe_state: false }
),(
Regex::new("Tor has not observed any network activity for the past").unwrap(),
ErrorLogSeverity::Unknown { wipe_state: false }
)];
static ref PROGRESS_REGEX: Regex = Regex::new("PROGRESS=([0-9]+)").unwrap();
}
#[test]
fn random_key() {
println!("x'{}'", hex::encode(rand::random::<[u8; 32]>()));
}
#[command(subcommands(list_services, logs, reset))]
pub fn tor() -> Result<(), Error> {
Ok(())
}
#[command(display(display_none))]
pub async fn reset(
#[context] ctx: RpcContext,
#[arg(rename = "wipe-state", short = 'w', long = "wipe-state")] wipe_state: bool,
#[arg] reason: String,
) -> Result<(), Error> {
ctx.net_controller
.tor
.reset(wipe_state, Error::new(eyre!("{reason}"), ErrorKind::Tor))
.await
}
fn display_services(services: Vec<OnionAddressV3>, matches: &ArgMatches) {
use prettytable::*;
if matches.is_present("format") {
return display_serializable(services, matches);
}
let mut table = Table::new();
for service in services {
let row = row![&service.to_string()];
table.add_row(row);
}
table.print_tty(false).unwrap();
}
#[command(rename = "list-services", display(display_services))]
pub async fn list_services(
#[context] ctx: RpcContext,
#[allow(unused_variables)]
#[arg(long = "format")]
format: Option<IoFormat>,
) -> Result<Vec<OnionAddressV3>, Error> {
ctx.net_controller.tor.list_services().await
}
#[command(
custom_cli(cli_logs(async, context(CliContext))),
subcommands(self(logs_nofollow(async)), logs_follow),
display(display_none)
)]
pub async fn logs(
#[arg(short = 'l', long = "limit")] limit: Option<usize>,
#[arg(short = 'c', long = "cursor")] cursor: Option<String>,
#[arg(short = 'B', long = "before", default)] before: bool,
#[arg(short = 'f', long = "follow", default)] follow: bool,
) -> Result<(Option<usize>, Option<String>, bool, bool), Error> {
Ok((limit, cursor, before, follow))
}
pub async fn cli_logs(
ctx: CliContext,
(limit, cursor, before, follow): (Option<usize>, Option<String>, bool, bool),
) -> Result<(), RpcError> {
if follow {
if cursor.is_some() {
return Err(RpcError::from(Error::new(
eyre!("The argument '--cursor <cursor>' cannot be used with '--follow'"),
crate::ErrorKind::InvalidRequest,
)));
}
if before {
return Err(RpcError::from(Error::new(
eyre!("The argument '--before' cannot be used with '--follow'"),
crate::ErrorKind::InvalidRequest,
)));
}
cli_logs_generic_follow(ctx, "net.tor.logs.follow", None, limit).await
} else {
cli_logs_generic_nofollow(ctx, "net.tor.logs", None, limit, cursor, before).await
}
}
pub async fn logs_nofollow(
_ctx: (),
(limit, cursor, before, _): (Option<usize>, Option<String>, bool, bool),
) -> Result<LogResponse, Error> {
fetch_logs(LogSource::Unit(SYSTEMD_UNIT), limit, cursor, before).await
}
#[command(rpc_only, rename = "follow", display(display_none))]
pub async fn logs_follow(
#[context] ctx: RpcContext,
#[parent_data] (limit, _, _, _): (Option<usize>, Option<String>, bool, bool),
) -> Result<LogFollowResponse, Error> {
follow_logs(ctx, LogSource::Unit(SYSTEMD_UNIT), limit).await
}
fn event_handler(_event: AsyncEvent<'static>) -> BoxFuture<'static, Result<(), ConnError>> {
async move { Ok(()) }.boxed()
}
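/// Handle to the background tor control task; every operation is forwarded
/// over an unbounded channel and fails with `ErrorKind::Tor` if that task has
/// exited.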
pub struct TorController(TorControl);
impl TorController {
pub fn new(tor_control: SocketAddr, tor_socks: SocketAddr) -> Self {
TorController(TorControl::new(tor_control, tor_socks))
}
pub async fn add(
&self,
key: TorSecretKeyV3,
external: u16,
target: SocketAddr,
) -> Result<Arc<()>, Error> {
let (reply, res) = oneshot::channel();
self.0
.send
.send(TorCommand::AddOnion {
key,
external,
target,
reply,
})
.ok()
.ok_or_else(|| Error::new(eyre!("TorControl died"), ErrorKind::Tor))?;
res.await
.ok()
.ok_or_else(|| Error::new(eyre!("TorControl died"), ErrorKind::Tor))
}
pub async fn gc(
&self,
key: Option<TorSecretKeyV3>,
external: Option<u16>,
) -> Result<(), Error> {
self.0
.send
.send(TorCommand::GC { key, external })
.ok()
.ok_or_else(|| Error::new(eyre!("TorControl died"), ErrorKind::Tor))
}
pub async fn reset(&self, wipe_state: bool, context: Error) -> Result<(), Error> {
self.0
.send
.send(TorCommand::Reset {
wipe_state,
context,
})
.ok()
.ok_or_else(|| Error::new(eyre!("TorControl died"), ErrorKind::Tor))
}
pub async fn list_services(&self) -> Result<Vec<OnionAddressV3>, Error> {
let (reply, res) = oneshot::channel();
self.0
.send
.send(TorCommand::GetInfo {
query: "onions/current".into(),
reply,
})
.ok()
.ok_or_else(|| Error::new(eyre!("TorControl died"), ErrorKind::Tor))?;
res.await
.ok()
.ok_or_else(|| Error::new(eyre!("TorControl died"), ErrorKind::Tor))??
.lines()
.map(|l| l.trim())
.filter(|l| !l.is_empty())
.map(|l| l.parse().with_kind(ErrorKind::Tor))
.collect()
}
}
type AuthenticatedConnection = AuthenticatedConn<
TcpStream,
Box<dyn Fn(AsyncEvent<'static>) -> BoxFuture<'static, Result<(), ConnError>> + Send + Sync>,
>;
enum TorCommand {
AddOnion {
key: TorSecretKeyV3,
external: u16,
target: SocketAddr,
reply: oneshot::Sender<Arc<()>>,
},
GC {
key: Option<TorSecretKeyV3>,
external: Option<u16>,
},
GetInfo {
query: String,
reply: oneshot::Sender<Result<String, Error>>,
},
Reset {
wipe_state: bool,
context: Error,
},
}
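/// One lifecycle of the tor control connection: restarts the systemd `tor`
/// unit (optionally wiping `/var/lib/tor` first), waits for bootstrap to
/// complete (up to roughly five minutes, bailing out if progress stalls for 30
/// seconds), re-registers every onion service that still has live bindings,
/// and then serves incoming commands until a reset is requested.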
#[instrument(skip_all)]
async fn torctl(
tor_control: SocketAddr,
tor_socks: SocketAddr,
recv: &mut mpsc::UnboundedReceiver<TorCommand>,
services: &mut BTreeMap<[u8; 64], BTreeMap<u16, BTreeMap<SocketAddr, Weak<()>>>>,
wipe_state: &AtomicBool,
health_timeout: &mut Duration,
) -> Result<(), Error> {
let bootstrap = async {
if Command::new("systemctl")
.arg("is-active")
.arg("--quiet")
.arg("tor")
.invoke(ErrorKind::Tor)
.await
.is_ok()
{
Command::new("systemctl")
.arg("stop")
.arg("tor")
.invoke(ErrorKind::Tor)
.await?;
for _ in 0..30 {
if TcpStream::connect(tor_control).await.is_err() {
break;
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
if TcpStream::connect(tor_control).await.is_ok() {
return Err(Error::new(
eyre!("Tor is failing to shut down"),
ErrorKind::Tor,
));
}
}
if wipe_state.load(std::sync::atomic::Ordering::SeqCst) {
tokio::fs::remove_dir_all("/var/lib/tor").await?;
wipe_state.store(false, std::sync::atomic::Ordering::SeqCst);
}
tokio::fs::create_dir_all("/var/lib/tor").await?;
Command::new("chown")
.arg("-R")
.arg("debian-tor")
.arg("/var/lib/tor")
.invoke(ErrorKind::Filesystem)
.await?;
Command::new("systemctl")
.arg("start")
.arg("tor")
.invoke(ErrorKind::Tor)
.await?;
let logs = journalctl(LogSource::Unit(SYSTEMD_UNIT), 0, None, false, true).await?;
let mut tcp_stream = None;
for _ in 0..60 {
if let Ok(conn) = TcpStream::connect(tor_control).await {
tcp_stream = Some(conn);
break;
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
let tcp_stream = tcp_stream.ok_or_else(|| {
Error::new(eyre!("Timed out waiting for tor to start"), ErrorKind::Tor)
})?;
tracing::info!("Tor is started");
let mut conn = torut::control::UnauthenticatedConn::new(tcp_stream);
let auth = conn
.load_protocol_info()
.await?
.make_auth_data()?
.ok_or_else(|| eyre!("Cookie Auth Not Available"))
.with_kind(crate::ErrorKind::Tor)?;
conn.authenticate(&auth).await?;
let mut connection: AuthenticatedConnection = conn.into_authenticated().await;
connection.set_async_event_handler(Some(Box::new(|event| event_handler(event))));
let mut bootstrapped = false;
let mut last_increment = (String::new(), Instant::now());
for _ in 0..300 {
match connection.get_info("status/bootstrap-phase").await {
Ok(a) => {
if a.contains("TAG=done") {
bootstrapped = true;
break;
}
if let Some(p) = PROGRESS_REGEX.captures(&a) {
if let Some(p) = p.get(1) {
if p.as_str() != &*last_increment.0 {
last_increment = (p.as_str().into(), Instant::now());
}
}
}
}
Err(e) => {
let e = Error::from(e);
tracing::error!("{}", e);
tracing::debug!("{:?}", e);
}
}
if last_increment.1.elapsed() > Duration::from_secs(30) {
return Err(Error::new(
eyre!("Tor stuck bootstrapping at {}%", last_increment.0),
ErrorKind::Tor,
));
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
if !bootstrapped {
return Err(Error::new(
eyre!("Timed out waiting for tor to bootstrap"),
ErrorKind::Tor,
));
}
Ok((connection, logs))
};
let pre_handler = async {
while let Some(command) = recv.recv().await {
match command {
TorCommand::AddOnion {
key,
external,
target,
reply,
} => {
let mut service = if let Some(service) = services.remove(&key.as_bytes()) {
service
} else {
BTreeMap::new()
};
let mut binding = service.remove(&external).unwrap_or_default();
let rc = if let Some(rc) =
Weak::upgrade(&binding.remove(&target).unwrap_or_default())
{
rc
} else {
Arc::new(())
};
binding.insert(target, Arc::downgrade(&rc));
service.insert(external, binding);
services.insert(key.as_bytes(), service);
reply.send(rc).unwrap_or_default();
}
TorCommand::GetInfo { reply, .. } => {
reply
.send(Err(Error::new(
eyre!("Tor has not finished bootstrapping..."),
ErrorKind::Tor,
)))
.unwrap_or_default();
}
TorCommand::GC { .. } => (),
TorCommand::Reset {
wipe_state: new_wipe_state,
context,
} => {
wipe_state.fetch_or(new_wipe_state, std::sync::atomic::Ordering::SeqCst);
return Err(context);
}
}
}
Ok(())
};
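// Race bootstrapping against the interim handler; a Reset (or the command channel
// closing) before bootstrap completes ends this lifecycle early.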
let (mut connection, mut logs) = tokio::select! {
res = bootstrap => res?,
res = pre_handler => return res,
};
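// Publish a throwaway onion service pointing at 127.0.0.1:80 that the log parser and
// health checker use to probe tor connectivity end to end.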
let hck_key = TorSecretKeyV3::generate();
connection
.add_onion_v3(
&hck_key,
false,
false,
false,
None,
&mut [(80, SocketAddr::from(([127, 0, 0, 1], 80)))].iter(),
)
.await?;
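// Re-publish every previously registered onion service that still has a live binding.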
for (key, service) in std::mem::take(services) {
let key = TorSecretKeyV3::from(key);
let bindings = service
.iter()
.flat_map(|(ext, int)| {
int.iter()
.find(|(_, rc)| rc.strong_count() > 0)
.map(|(addr, _)| (*ext, SocketAddr::from(*addr)))
})
.collect::<Vec<_>>();
if !bindings.is_empty() {
services.insert(key.as_bytes(), service);
connection
.add_onion_v3(&key, false, false, false, None, &mut bindings.iter())
.await?;
}
}
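// Main command loop: keep the in-memory service map in sync with tor, deleting and
// re-adding onion services as bindings are added or garbage collected.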
let handler = async {
while let Some(command) = recv.recv().await {
match command {
TorCommand::AddOnion {
key,
external,
target,
reply,
} => {
let mut rm_res = Ok(());
let onion_base = key
.public()
.get_onion_address()
.get_address_without_dot_onion();
let mut service = if let Some(service) = services.remove(&key.as_bytes()) {
rm_res = connection.del_onion(&onion_base).await;
service
} else {
BTreeMap::new()
};
let mut binding = service.remove(&external).unwrap_or_default();
let rc = if let Some(rc) =
Weak::upgrade(&binding.remove(&target).unwrap_or_default())
{
rc
} else {
Arc::new(())
};
binding.insert(target, Arc::downgrade(&rc));
service.insert(external, binding);
let bindings = service
.iter()
.flat_map(|(ext, int)| {
int.iter()
.find(|(_, rc)| rc.strong_count() > 0)
.map(|(addr, _)| (*ext, SocketAddr::from(*addr)))
})
.collect::<Vec<_>>();
services.insert(key.as_bytes(), service);
reply.send(rc).unwrap_or_default();
rm_res?;
connection
.add_onion_v3(&key, false, false, false, None, &mut bindings.iter())
.await?;
}
TorCommand::GC { key, external } => {
for key in if key.is_some() {
itertools::Either::Left(key.into_iter().map(|k| k.as_bytes()))
} else {
itertools::Either::Right(services.keys().cloned().collect_vec().into_iter())
} {
let key = TorSecretKeyV3::from(key);
let onion_base = key
.public()
.get_onion_address()
.get_address_without_dot_onion();
if let Some(mut service) = services.remove(&key.as_bytes()) {
for external in if external.is_some() {
itertools::Either::Left(external.into_iter())
} else {
itertools::Either::Right(
service.keys().copied().collect_vec().into_iter(),
)
} {
if let Some(mut binding) = service.remove(&external) {
binding = binding
.into_iter()
.filter(|(_, rc)| rc.strong_count() > 0)
.collect();
if !binding.is_empty() {
service.insert(external, binding);
}
}
}
let rm_res = connection.del_onion(&onion_base).await;
if !service.is_empty() {
let bindings = service
.iter()
.flat_map(|(ext, int)| {
int.iter()
.find(|(_, rc)| rc.strong_count() > 0)
.map(|(addr, _)| (*ext, SocketAddr::from(*addr)))
})
.collect::<Vec<_>>();
if !bindings.is_empty() {
services.insert(key.as_bytes(), service);
}
rm_res?;
if !bindings.is_empty() {
connection
.add_onion_v3(
&key,
false,
false,
false,
None,
&mut bindings.iter(),
)
.await?;
}
} else {
rm_res?;
}
}
}
}
TorCommand::GetInfo { query, reply } => {
reply
.send(connection.get_info(&query).await.with_kind(ErrorKind::Tor))
.unwrap_or_default();
}
TorCommand::Reset {
wipe_state: new_wipe_state,
context,
} => {
wipe_state.fetch_or(new_wipe_state, std::sync::atomic::Ordering::SeqCst);
return Err(context);
}
}
}
Ok(())
};
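// Scan tor's journal for known problem messages. Entries marked Unknown only count as
// failures if probing the health-check onion also fails; Fatal entries fail immediately.
// Either way, the matched rule may also request stopping tor and wiping /var/lib/tor.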
let log_parser = async {
while let Some(log) = logs.try_next().await? {
for (regex, severity) in &*LOG_REGEXES {
if regex.is_match(&log.message) {
let (check, wipe_state) = match severity {
ErrorLogSeverity::Fatal { wipe_state } => (false, *wipe_state),
ErrorLogSeverity::Unknown { wipe_state } => (true, *wipe_state),
};
if !check
|| tokio::time::timeout(
Duration::from_secs(30),
tokio_socks::tcp::Socks5Stream::connect(
tor_socks,
(hck_key.public().get_onion_address().to_string(), 80),
),
)
.await
.map_err(|e| tracing::warn!("Tor is confirmed to be down: {e}"))
.and_then(|a| {
a.map_err(|e| tracing::warn!("Tor is confirmed to be down: {e}"))
})
.is_err()
{
if wipe_state {
Command::new("systemctl")
.arg("stop")
.arg("tor")
.invoke(ErrorKind::Tor)
.await?;
tokio::fs::remove_dir_all("/var/lib/tor").await?;
}
return Err(Error::new(eyre!("{}", log.message), ErrorKind::Tor));
}
}
}
}
Err(Error::new(eyre!("Log stream terminated"), ErrorKind::Tor))
};
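// Every 30 seconds, dial the health-check onion through the SOCKS proxy. If it has been
// unreachable for longer than the current timeout, double the timeout, request a state
// wipe, and fail so the caller restarts tor.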
let health_checker = async {
let mut last_success = Instant::now();
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
if tokio::time::timeout(
Duration::from_secs(30),
tokio_socks::tcp::Socks5Stream::connect(
tor_socks,
(hck_key.public().get_onion_address().to_string(), 80),
),
)
.await
.map_err(|e| e.to_string())
.and_then(|e| e.map_err(|e| e.to_string()))
.is_err()
{
if last_success.elapsed() > *health_timeout {
let err = Error::new(eyre!("Tor health check failed for longer than current timeout ({health_timeout:?})"), crate::ErrorKind::Tor);
*health_timeout *= 2;
wipe_state.store(true, std::sync::atomic::Ordering::SeqCst);
return Err(err);
}
} else {
last_success = Instant::now();
}
}
};
tokio::select! {
res = handler => res?,
res = log_parser => res?,
res = health_checker => res?,
}
Ok(())
}
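/// Handle that owns the background task managing the system tor daemon and the channel
/// used to send it `TorCommand`s. The task restarts `torctl` whenever it fails.
///
/// A rough usage sketch (assuming tor's default control and SOCKS ports on localhost;
/// how the rest of the module uses the handle is not shown here):
///
/// ```ignore
/// let ctl = TorControl::new(
///     SocketAddr::from(([127, 0, 0, 1], 9051)),
///     SocketAddr::from(([127, 0, 0, 1], 9050)),
/// );
/// // The networking code then issues `TorCommand`s over `ctl.send`.
/// ```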
struct TorControl {
_thread: NonDetachingJoinHandle<()>,
send: mpsc::UnboundedSender<TorCommand>,
}
impl TorControl {
pub fn new(tor_control: SocketAddr, tor_socks: SocketAddr) -> Self {
let (send, mut recv) = mpsc::unbounded_channel();
Self {
_thread: tokio::spawn(async move {
let mut services = BTreeMap::new();
let wipe_state = AtomicBool::new(false);
let mut health_timeout = Duration::from_secs(STARTING_HEALTH_TIMEOUT);
while let Err(e) = torctl(
tor_control,
tor_socks,
&mut recv,
&mut services,
&wipe_state,
&mut health_timeout,
)
.await
{
tracing::error!("{e}: Restarting tor");
tracing::debug!("{e:?}");
}
tracing::info!("TorControl is shut down.");
})
.into(),
send,
}
}
}
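// Manual smoke test against a locally running tor control port (127.0.0.1:9051); ignored
// by default, run with `cargo test -- --ignored`.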
#[tokio::test]
#[ignore]
async fn test() {
let mut conn = torut::control::UnauthenticatedConn::new(
TcpStream::connect(SocketAddr::from(([127, 0, 0, 1], 9051)))
.await
.unwrap(), // TODO
);
let auth = conn
.load_protocol_info()
.await
.unwrap()
.make_auth_data()
.unwrap()
.ok_or_else(|| eyre!("Cookie Auth Not Available"))
.with_kind(crate::ErrorKind::Tor)
.unwrap();
conn.authenticate(&auth).await.unwrap();
let mut connection: AuthenticatedConn<
TcpStream,
fn(AsyncEvent<'static>) -> BoxFuture<'static, Result<(), ConnError>>,
> = conn.into_authenticated().await;
let tor_key = torut::onion::TorSecretKeyV3::generate();
connection.get_conf("SocksPort").await.unwrap();
connection
.add_onion_v3(
&tor_key,
false,
false,
false,
None,
&mut [(443_u16, SocketAddr::from(([127, 0, 0, 1], 8443)))].iter(),
)
.await
.unwrap();
connection
.del_onion(
&tor_key
.public()
.get_onion_address()
.get_address_without_dot_onion(),
)
.await
.unwrap();
connection
.add_onion_v3(
&tor_key,
false,
false,
false,
None,
&mut [(8443_u16, SocketAddr::from(([127, 0, 0, 1], 8443)))].iter(),
)
.await
.unwrap();
}

View File

@@ -0,0 +1,166 @@
use std::convert::Infallible;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use std::path::Path;
use async_stream::try_stream;
use color_eyre::eyre::eyre;
use futures::stream::BoxStream;
use futures::{StreamExt, TryStreamExt};
use ipnet::{Ipv4Net, Ipv6Net};
use tokio::net::{TcpListener, TcpStream};
use tokio::process::Command;
use crate::util::Invoke;
use crate::Error;
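/// Pulls the CIDR address (fourth column) out of each line of `ip -o addr show` output.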
fn parse_iface_ip(output: &str) -> Result<Vec<&str>, Error> {
let output = output.trim();
if output.is_empty() {
return Ok(Vec::new());
}
let mut res = Vec::new();
for line in output.lines() {
if let Some(ip) = line.split_ascii_whitespace().nth(3) {
res.push(ip)
} else {
return Err(Error::new(
eyre!("malformed output from `ip`"),
crate::ErrorKind::Network,
));
}
}
Ok(res)
}
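/// Returns the first IPv4 address (and its subnet) assigned to `iface`, if any.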
pub async fn get_iface_ipv4_addr(iface: &str) -> Result<Option<(Ipv4Addr, Ipv4Net)>, Error> {
Ok(parse_iface_ip(&String::from_utf8(
Command::new("ip")
.arg("-4")
.arg("-o")
.arg("addr")
.arg("show")
.arg(iface)
.invoke(crate::ErrorKind::Network)
.await?,
)?)?
.into_iter()
.map(|s| Ok::<_, Error>((s.split("/").next().unwrap().parse()?, s.parse()?)))
.next()
.transpose()?)
}
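/// Returns the first non-link-local (not `fe80::`) IPv6 address (and its subnet) assigned to `iface`, if any.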
pub async fn get_iface_ipv6_addr(iface: &str) -> Result<Option<(Ipv6Addr, Ipv6Net)>, Error> {
Ok(parse_iface_ip(&String::from_utf8(
Command::new("ip")
.arg("-6")
.arg("-o")
.arg("addr")
.arg("show")
.arg(iface)
.invoke(crate::ErrorKind::Network)
.await?,
)?)?
.into_iter()
.find(|ip| !ip.starts_with("fe80::"))
.map(|s| Ok::<_, Error>((s.split("/").next().unwrap().parse()?, s.parse()?)))
.transpose()?)
}
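/// An interface is considered physical if it has a `device` entry in sysfs.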
pub async fn iface_is_physical(iface: &str) -> bool {
tokio::fs::metadata(Path::new("/sys/class/net").join(iface).join("device"))
.await
.is_ok()
}
pub async fn iface_is_wireless(iface: &str) -> bool {
tokio::fs::metadata(Path::new("/sys/class/net").join(iface).join("wireless"))
.await
.is_ok()
}
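/// Streams the name of every interface listed under `/sys/class/net`.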
pub fn list_interfaces() -> BoxStream<'static, Result<String, Error>> {
try_stream! {
let mut ifaces = tokio::fs::read_dir("/sys/class/net").await?;
while let Some(iface) = ifaces.next_entry().await? {
if let Ok(iface) = iface.file_name().into_string() {
yield iface;
}
}
}
.boxed()
}
pub async fn find_wifi_iface() -> Result<Option<String>, Error> {
let mut ifaces = list_interfaces();
while let Some(iface) = ifaces.try_next().await? {
if iface_is_wireless(&iface).await {
return Ok(Some(iface));
}
}
Ok(None)
}
pub async fn find_eth_iface() -> Result<String, Error> {
let mut ifaces = list_interfaces();
while let Some(iface) = ifaces.try_next().await? {
if iface_is_physical(&iface).await && !iface_is_wireless(&iface).await {
return Ok(iface);
}
}
Err(Error::new(
eyre!("Could not detect ethernet interface"),
crate::ErrorKind::Network,
))
}
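/// A hyper `Accept` implementation that yields exactly one pre-established connection and
/// then reports the stream as exhausted.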
#[pin_project::pin_project]
pub struct SingleAccept<T>(Option<T>);
impl<T> SingleAccept<T> {
pub fn new(conn: T) -> Self {
Self(Some(conn))
}
}
impl<T> hyper::server::accept::Accept for SingleAccept<T> {
type Conn = T;
type Error = Infallible;
fn poll_accept(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
std::task::Poll::Ready(self.project().0.take().map(Ok))
}
}
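/// A hyper `Accept` implementation that accepts connections from any of several bound
/// `TcpListener`s.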
pub struct TcpListeners {
listeners: Vec<TcpListener>,
}
impl TcpListeners {
pub fn new(listeners: impl IntoIterator<Item = TcpListener>) -> Self {
Self {
listeners: listeners.into_iter().collect(),
}
}
pub async fn accept(&self) -> std::io::Result<(TcpStream, SocketAddr)> {
futures::future::select_all(self.listeners.iter().map(|l| Box::pin(l.accept())))
.await
.0
}
}
impl hyper::server::accept::Accept for TcpListeners {
type Conn = TcpStream;
type Error = std::io::Error;
fn poll_accept(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<Self::Conn, Self::Error>>> {
for listener in self.listeners.iter() {
let poll = listener.poll_accept(cx);
if poll.is_ready() {
return poll.map(|a| a.map(|a| a.0)).map(Some);
}
}
std::task::Poll::Pending
}
}

Some files were not shown because too many files have changed in this diff