diff --git a/appmgr/src/db/model.rs b/appmgr/src/db/model.rs index c9bbdd1e8..59e02521d 100644 --- a/appmgr/src/db/model.rs +++ b/appmgr/src/db/model.rs @@ -79,6 +79,7 @@ pub struct ServerInfo { pub unread_notification_count: u64, pub connection_addresses: ConnectionAddresses, pub share_stats: bool, + #[model] pub update_progress: Option, } @@ -90,7 +91,7 @@ pub enum ServerStatus { BackingUp, } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Deserialize, Serialize, HasModel)] #[serde(rename_all = "kebab-case")] pub struct UpdateProgress { pub size: Option, diff --git a/appmgr/src/update/mod.rs b/appmgr/src/update/mod.rs index 2a20afafb..3017b723c 100644 --- a/appmgr/src/update/mod.rs +++ b/appmgr/src/update/mod.rs @@ -188,6 +188,7 @@ async fn maybe_do_update(ctx: RpcContext) -> Result>, Error let mounted_boot = mount_label(Boot).await?; let (new_label, _current_label) = query_mounted_label().await?; let (size, download) = download_file( + ctx.db.handle(), &EosUrl { base: info.eos_marketplace.clone(), version: latest_version, @@ -204,8 +205,8 @@ async fn maybe_do_update(ctx: RpcContext) -> Result>, Error let rev = tx.commit(None).await?; tokio::spawn(async move { - let res = do_update(download, new_label, mounted_boot).await; let mut db = ctx.db.handle(); + let res = do_update(download, new_label, mounted_boot).await; let mut info = crate::db::DatabaseModel::new() .server_info() .get_mut(&mut db) @@ -276,37 +277,38 @@ impl std::fmt::Display for EosUrl { } } -async fn download_file<'a>( +async fn download_file<'a, Db: DbHandle + 'a>( + mut db: Db, eos_url: &EosUrl, new_label: NewLabel, ) -> Result<(Option, impl Future> + 'a), Error> { let download_request = reqwest::get(eos_url.to_string()) .await .with_kind(ErrorKind::Network)?; - Ok(( - download_request - .headers() - .get("content-length") - .and_then(|a| a.to_str().ok()) - .map(|l| l.parse()) - .transpose()?, - async move { - let hash_from_header: String = "".to_owned(); // download_request - // .headers() - // .get(HEADER_KEY) - // .ok_or_else(|| Error::new(anyhow!("No {} in headers", HEADER_KEY), ErrorKind::Network))? - // .to_str() - // .with_kind(ErrorKind::InvalidRequest)? - // .to_owned(); - let stream_download = download_request.bytes_stream(); - let file_sum = write_stream_to_label(stream_download, new_label).await?; - check_download(&hash_from_header, file_sum).await?; - Ok(()) - }, - )) + let size = download_request + .headers() + .get("content-length") + .and_then(|a| a.to_str().ok()) + .map(|l| l.parse()) + .transpose()?; + Ok((size, async move { + let hash_from_header: String = "".to_owned(); // download_request + // .headers() + // .get(HEADER_KEY) + // .ok_or_else(|| Error::new(anyhow!("No {} in headers", HEADER_KEY), ErrorKind::Network))? + // .to_str() + // .with_kind(ErrorKind::InvalidRequest)? + // .to_owned(); + let stream_download = download_request.bytes_stream(); + let file_sum = write_stream_to_label(&mut db, size, stream_download, new_label).await?; + check_download(&hash_from_header, file_sum).await?; + Ok(()) + })) } -async fn write_stream_to_label( +async fn write_stream_to_label( + db: &mut Db, + size: Option, stream_download: impl Stream>, file: NewLabel, ) -> Result, Error> { @@ -318,10 +320,17 @@ async fn write_stream_to_label( .with_kind(ErrorKind::Filesystem)?; let mut hasher = Sha256::new(); pin!(stream_download); + let mut downloaded = 0; while let Some(Ok(item)) = stream_download.next().await { file.write_all(&item) .await .with_kind(ErrorKind::Filesystem)?; + downloaded += item.len() as u64; + crate::db::DatabaseModel::new() + .server_info() + .update_progress() + .put(db, &UpdateProgress { size, downloaded }) + .await?; hasher.update(item); } file.flush().await.with_kind(ErrorKind::Filesystem)?;