use std::ops::Deref; use std::path::Path; use std::sync::Arc; use std::time::Duration; use futures::{Future, StreamExt}; use helpers::NonDetachingJoinHandle; use imbl_value::InternedString; use josekit::jwk::Jwk; use patch_db::PatchDb; use rpc_toolkit::Context; use serde::{Deserialize, Serialize}; use tokio::sync::OnceCell; use tokio::sync::broadcast::Sender; use tracing::instrument; use ts_rs::TS; use crate::MAIN_DATA; use crate::account::AccountInfo; use crate::context::RpcContext; use crate::context::config::ServerConfig; use crate::disk::OsPartitionInfo; use crate::hostname::Hostname; use crate::net::gateway::UpgradableListener; use crate::net::web_server::{WebServer, WebServerAcceptorSetter}; use crate::prelude::*; use crate::progress::FullProgressTracker; use crate::rpc_continuations::{Guid, RpcContinuation, RpcContinuations}; use crate::setup::SetupProgress; use crate::shutdown::Shutdown; use crate::util::net::WebSocketExt; lazy_static::lazy_static! { pub static ref CURRENT_SECRET: Jwk = Jwk::generate_ec_key(josekit::jwk::alg::ec::EcCurve::P256).unwrap_or_else(|e| { tracing::debug!("{:?}", e); tracing::error!("Couldn't generate ec key"); panic!("Couldn't generate ec key") }); } #[derive(Debug, Clone, Deserialize, Serialize, TS)] #[serde(rename_all = "camelCase")] #[ts(export)] pub struct SetupResult { pub tor_addresses: Vec, #[ts(type = "string")] pub hostname: Hostname, #[ts(type = "string")] pub lan_address: InternedString, pub root_ca: String, } impl TryFrom<&AccountInfo> for SetupResult { type Error = Error; fn try_from(value: &AccountInfo) -> Result { Ok(Self { tor_addresses: value .tor_keys .iter() .map(|tor_key| format!("https://{}", tor_key.onion_address())) .collect(), hostname: value.hostname.clone(), lan_address: value.hostname.lan_address(), root_ca: String::from_utf8(value.root_ca_cert.to_pem()?)?, }) } } pub struct SetupContextSeed { pub webserver: WebServerAcceptorSetter, pub config: ServerConfig, pub os_partitions: OsPartitionInfo, pub disable_encryption: bool, pub progress: FullProgressTracker, pub task: OnceCell>, pub result: OnceCell>, pub disk_guid: OnceCell>, pub shutdown: Sender>, pub rpc_continuations: RpcContinuations, } #[derive(Clone)] pub struct SetupContext(Arc); impl SetupContext { #[instrument(skip_all)] pub fn init( webserver: &WebServer, config: &ServerConfig, ) -> Result { let (shutdown, _) = tokio::sync::broadcast::channel(1); let mut progress = FullProgressTracker::new(); progress.enable_logging(true); Ok(Self(Arc::new(SetupContextSeed { webserver: webserver.acceptor_setter(), config: config.clone(), os_partitions: config.os_partitions.clone().ok_or_else(|| { Error::new( eyre!("missing required configuration: `os-partitions`"), ErrorKind::NotFound, ) })?, disable_encryption: config.disable_encryption.unwrap_or(false), progress, task: OnceCell::new(), result: OnceCell::new(), disk_guid: OnceCell::new(), shutdown, rpc_continuations: RpcContinuations::new(), }))) } #[instrument(skip_all)] pub async fn db(&self) -> Result { let db_path = Path::new(MAIN_DATA).join("embassy.db"); let db = PatchDb::open(&db_path) .await .with_ctx(|_| (crate::ErrorKind::Filesystem, db_path.display().to_string()))?; Ok(db) } pub fn run_setup(&self, f: F) -> Result<(), Error> where F: FnOnce() -> Fut + Send + 'static, Fut: Future> + Send, { let local_ctx = self.clone(); self.task .set( tokio::spawn(async move { local_ctx .result .get_or_init(|| async { match f().await { Ok(res) => { tracing::info!("Setup complete!"); Ok(res) } Err(e) => { tracing::error!("Setup failed: {e}"); tracing::debug!("{e:?}"); Err(e) } } }) .await; local_ctx.progress.complete(); }) .into(), ) .map_err(|_| { if self.result.initialized() { Error::new(eyre!("Setup already complete"), ErrorKind::InvalidRequest) } else { Error::new( eyre!("Setup already in progress"), ErrorKind::InvalidRequest, ) } })?; Ok(()) } pub async fn progress(&self) -> SetupProgress { use axum::extract::ws; let guid = Guid::new(); let progress_tracker = self.progress.clone(); let progress = progress_tracker.snapshot(); self.rpc_continuations .add( guid.clone(), RpcContinuation::ws( |mut ws| async move { if let Err(e) = async { let mut stream = progress_tracker.stream(Some(Duration::from_millis(100))); loop { tokio::select! { progress = stream.next() => { if let Some(progress) = progress { ws.send(ws::Message::Text( serde_json::to_string(&progress) .with_kind(ErrorKind::Serialization)? .into(), )) .await .with_kind(ErrorKind::Network)?; if progress.overall.is_complete() { return ws.normal_close("complete").await; } } else { return ws.normal_close("complete").await; } } msg = ws.recv() => { if msg.transpose().with_kind(ErrorKind::Network)?.is_none() { return Ok(()) } } } } } .await { tracing::error!("Error in setup progress websocket: {e}"); tracing::debug!("{e:?}"); } }, Duration::from_secs(30), ), ) .await; SetupProgress { progress, guid } } } impl AsRef for SetupContext { fn as_ref(&self) -> &Jwk { &*CURRENT_SECRET } } impl AsRef for SetupContext { fn as_ref(&self) -> &RpcContinuations { &self.rpc_continuations } } impl Context for SetupContext {} impl Deref for SetupContext { type Target = SetupContextSeed; fn deref(&self) -> &Self::Target { &*self.0 } }