diff --git a/backend/src/backup/restore.rs b/backend/src/backup/restore.rs index e33871fc0..6fe2ebbfc 100644 --- a/backend/src/backup/restore.rs +++ b/backend/src/backup/restore.rs @@ -1,13 +1,13 @@ -use std::collections::BTreeMap; use std::path::Path; use std::sync::atomic::Ordering; use std::sync::Arc; use std::time::Duration; +use std::{collections::BTreeMap, pin::Pin}; use clap::ArgMatches; use color_eyre::eyre::eyre; -use futures::future::BoxFuture; -use futures::FutureExt; +use futures::{future::BoxFuture, stream, Future}; +use futures::{FutureExt, StreamExt}; use openssl::x509::X509; use patch_db::{DbHandle, PatchDbHandle}; use rpc_toolkit::command; @@ -64,52 +64,34 @@ pub async fn restore_packages_rpc( let (backup_guard, tasks, _) = restore_packages(&ctx, &mut db, backup_guard, ids).await?; tokio::spawn(async move { - let res = futures::future::join_all(tasks).await; - for res in res { - match res.with_kind(crate::ErrorKind::Unknown) { - Ok((Ok(_), _)) => (), - Ok((Err(err), package_id)) => { - if let Err(err) = ctx - .notification_manager - .notify( - &mut db, - Some(package_id.clone()), - NotificationLevel::Error, - "Restoration Failure".to_string(), - format!("Error restoring package {}: {}", package_id, err), - (), - None, - ) - .await - { - tracing::error!("Failed to notify: {}", err); - tracing::debug!("{:?}", err); - }; - tracing::error!("Error restoring package {}: {}", package_id, err); - tracing::debug!("{:?}", err); - } - Err(e) => { - if let Err(err) = ctx - .notification_manager - .notify( - &mut db, - None, - NotificationLevel::Error, - "Restoration Failure".to_string(), - format!("Error during restoration: {}", e), - (), - None, - ) - .await - { - tracing::error!("Failed to notify: {}", err); + stream::iter(tasks.into_iter().map(|x| (x, ctx.clone()))) + .for_each_concurrent(5, |(res, ctx)| async move { + let mut db = ctx.db.handle(); + match res.await { + (Ok(_), _) => (), + (Err(err), package_id) => { + if let Err(err) = ctx + .notification_manager + .notify( + &mut db, + Some(package_id.clone()), + NotificationLevel::Error, + "Restoration Failure".to_string(), + format!("Error restoring package {}: {}", package_id, err), + (), + None, + ) + .await + { + tracing::error!("Failed to notify: {}", err); + tracing::debug!("{:?}", err); + }; + tracing::error!("Error restoring package {}: {}", package_id, err); tracing::debug!("{:?}", err); } - tracing::error!("Error restoring packages: {}", e); - tracing::debug!("{:?}", e); } - } - } + }) + .await; if let Err(e) = backup_guard.unmount().await { tracing::error!("Error unmounting backup drive: {}", e); tracing::debug!("{:?}", e); @@ -257,40 +239,31 @@ pub async fn recover_full_embassy( .collect(); let (backup_guard, tasks, progress_info) = restore_packages(&rpc_ctx, &mut db, backup_guard, ids).await?; - + let task_consumer_rpc_ctx = rpc_ctx.clone(); tokio::select! { - res = futures::future::join_all(tasks) => { - for res in res { - match res.with_kind(crate::ErrorKind::Unknown) { - Ok((Ok(_), _)) => (), - Ok((Err(err), package_id)) => { - if let Err(err) = rpc_ctx.notification_manager.notify( - &mut db, - Some(package_id.clone()), - NotificationLevel::Error, - "Restoration Failure".to_string(), format!("Error restoring package {}: {}", package_id,err), (), None).await{ - tracing::error!("Failed to notify: {}", err); + _ = async move { + stream::iter(tasks.into_iter().map(|x| (x, task_consumer_rpc_ctx.clone()))) + .for_each_concurrent(5, |(res, ctx)| async move { + let mut db = ctx.db.handle(); + match res.await { + (Ok(_), _) => (), + (Err(err), package_id) => { + if let Err(err) = ctx.notification_manager.notify( + &mut db, + Some(package_id.clone()), + NotificationLevel::Error, + "Restoration Failure".to_string(), format!("Error restoring package {}: {}", package_id,err), (), None).await{ + tracing::error!("Failed to notify: {}", err); + tracing::debug!("{:?}", err); + }; + tracing::error!("Error restoring package {}: {}", package_id, err); tracing::debug!("{:?}", err); - }; - tracing::error!("Error restoring package {}: {}", package_id, err); - tracing::debug!("{:?}", err); - }, - Err(e) => { - if let Err(err) = rpc_ctx.notification_manager.notify( - &mut db, - None, - NotificationLevel::Error, - "Restoration Failure".to_string(), format!("Error during restoration: {}", e), (), None).await { + }, + } + }).await; - tracing::error!("Failed to notify: {}", err); - tracing::debug!("{:?}", err); - } - tracing::error!("Error restoring packages: {}", e); - tracing::debug!("{:?}", e); - }, + } => { - } - } }, _ = approximate_progress_loop(&ctx, &rpc_ctx, progress_info) => unreachable!(concat!(module_path!(), "::approximate_progress_loop should not terminate")), } @@ -314,7 +287,7 @@ async fn restore_packages( ) -> Result< ( BackupMountGuard, - Vec, PackageId)>>, + Vec, PackageId)>>, ProgressInfo, ), Error, @@ -333,7 +306,7 @@ async fn restore_packages( .insert(id.clone(), dir_size(backup_dir(&id)).await?); progress_info.target_volume_size.insert(id.clone(), 0); let package_id = id.clone(); - tasks.push(tokio::spawn( + tasks.push( async move { if let Err(e) = task.await { tracing::error!("Error restoring package {}: {}", id, e); @@ -343,8 +316,9 @@ async fn restore_packages( Ok(()) } } - .map(|x| (x, package_id)), - )); + .map(|x| (x, package_id)) + .boxed(), + ); } Ok((backup_guard, tasks, progress_info))