use std::panic::UnwindSafe; use std::sync::Arc; use std::time::Duration; use futures::Future; use imbl_value::{InOMap, InternedString}; use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use itertools::Itertools; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncSeek, AsyncWrite}; use tokio::sync::{mpsc, watch}; use crate::db::model::DatabaseModel; use crate::prelude::*; lazy_static::lazy_static! { static ref SPINNER: ProgressStyle = ProgressStyle::with_template("{spinner} {msg}...").unwrap(); static ref PERCENTAGE: ProgressStyle = ProgressStyle::with_template("{msg} {percent}% {wide_bar} [{bytes}/{total_bytes}] [{binary_bytes_per_sec} {eta}]").unwrap(); static ref BYTES: ProgressStyle = ProgressStyle::with_template("{spinner} {wide_msg} [{bytes}/?] [{binary_bytes_per_sec} {elapsed}]").unwrap(); } #[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord)] #[serde(untagged)] pub enum Progress { Complete(bool), Progress { done: u64, total: Option }, } impl Progress { pub fn new() -> Self { Progress::Complete(false) } pub fn update_bar(self, bar: &ProgressBar) { match self { Self::Complete(false) => { bar.set_style(SPINNER.clone()); bar.tick(); } Self::Complete(true) => { bar.finish(); } Self::Progress { done, total: None } => { bar.set_style(BYTES.clone()); bar.set_position(done); bar.tick(); } Self::Progress { done, total: Some(total), } => { bar.set_style(PERCENTAGE.clone()); bar.set_position(done); bar.set_length(total); bar.tick(); } } } pub fn set_done(&mut self, done: u64) { *self = match *self { Self::Complete(false) => Self::Progress { done, total: None }, Self::Progress { mut done, total } => { if let Some(total) = total { if done > total { done = total; } } Self::Progress { done, total } } Self::Complete(true) => Self::Complete(true), }; } pub fn set_total(&mut self, total: u64) { *self = match *self { Self::Complete(false) => Self::Progress { done: 0, total: Some(total), }, Self::Progress { done, .. } => Self::Progress { done, total: Some(total), }, Self::Complete(true) => Self::Complete(true), } } pub fn add_total(&mut self, total: u64) { if let Self::Progress { done, total: Some(old), } = *self { *self = Self::Progress { done, total: Some(old + total), }; } else { self.set_total(total) } } pub fn complete(&mut self) { *self = Self::Complete(true); } } impl std::ops::Add for Progress { type Output = Self; fn add(self, rhs: u64) -> Self::Output { match self { Self::Complete(false) => Self::Progress { done: rhs, total: None, }, Self::Progress { done, total } => { let mut done = done + rhs; if let Some(total) = total { if done > total { done = total; } } Self::Progress { done, total } } Self::Complete(true) => Self::Complete(true), } } } impl std::ops::AddAssign for Progress { fn add_assign(&mut self, rhs: u64) { *self = *self + rhs; } } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct NamedProgress { pub name: InternedString, pub progress: Progress, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct FullProgress { pub overall: Progress, pub phases: Vec, } impl FullProgress { pub fn new() -> Self { Self { overall: Progress::new(), phases: Vec::new(), } } } pub struct FullProgressTracker { overall: Arc>, overall_recv: watch::Receiver, phases: InOMap>, new_phase: ( mpsc::UnboundedSender<(InternedString, watch::Receiver)>, mpsc::UnboundedReceiver<(InternedString, watch::Receiver)>, ), } impl FullProgressTracker { pub fn new() -> Self { let (overall, overall_recv) = watch::channel(Progress::new()); Self { overall: Arc::new(overall), overall_recv, phases: InOMap::new(), new_phase: mpsc::unbounded_channel(), } } fn fill_phases(&mut self) -> bool { let mut changed = false; while let Ok((name, phase)) = self.new_phase.1.try_recv() { self.phases.insert(name, phase); changed = true; } changed } pub fn snapshot(&mut self) -> FullProgress { self.fill_phases(); FullProgress { overall: *self.overall.borrow(), phases: self .phases .iter() .map(|(name, progress)| NamedProgress { name: name.clone(), progress: *progress.borrow(), }) .collect(), } } pub async fn changed(&mut self) { if self.fill_phases() { return; } let phases = self .phases .iter_mut() .map(|(_, p)| Box::pin(p.changed())) .collect_vec(); tokio::select! { _ = self.overall_recv.changed() => (), _ = futures::future::select_all(phases) => (), } } pub fn handle(&self) -> FullProgressTrackerHandle { FullProgressTrackerHandle { overall: self.overall.clone(), new_phase: self.new_phase.0.clone(), } } pub fn sync_to_db( mut self, db: PatchDb, deref: DerefFn, min_interval: Option, ) -> impl Future> + 'static where DerefFn: Fn(&mut DatabaseModel) -> Option<&mut Model> + 'static, for<'a> &'a DerefFn: UnwindSafe + Send, { async move { loop { let progress = self.snapshot(); if db .mutate(|v| { if let Some(p) = deref(v) { p.ser(&progress)?; Ok(false) } else { Ok(true) } }) .await? { break; } tokio::join!(self.changed(), async { if let Some(interval) = min_interval { tokio::time::sleep(interval).await } else { futures::future::ready(()).await } }); } Ok(()) } } } #[derive(Clone)] pub struct FullProgressTrackerHandle { overall: Arc>, new_phase: mpsc::UnboundedSender<(InternedString, watch::Receiver)>, } impl FullProgressTrackerHandle { pub fn add_phase( &self, name: InternedString, overall_contribution: Option, ) -> PhaseProgressTrackerHandle { if let Some(overall_contribution) = overall_contribution { self.overall .send_modify(|o| o.add_total(overall_contribution)); } let (send, recv) = watch::channel(Progress::new()); let _ = self.new_phase.send((name, recv)); PhaseProgressTrackerHandle { overall: self.overall.clone(), overall_contribution, contributed: 0, progress: send, } } pub fn complete(&self) { self.overall.send_modify(|o| o.complete()); } } pub struct PhaseProgressTrackerHandle { overall: Arc>, overall_contribution: Option, contributed: u64, progress: watch::Sender, } impl PhaseProgressTrackerHandle { fn update_overall(&mut self) { if let Some(overall_contribution) = self.overall_contribution { let contribution = match *self.progress.borrow() { Progress::Complete(true) => overall_contribution, Progress::Progress { done, total: Some(total), } => ((done as f64 / total as f64) * overall_contribution as f64) as u64, _ => 0, }; if contribution > self.contributed { self.overall .send_modify(|o| *o += contribution - self.contributed); self.contributed = contribution; } } } pub fn set_done(&mut self, done: u64) { self.progress.send_modify(|p| p.set_done(done)); self.update_overall(); } pub fn set_total(&mut self, total: u64) { self.progress.send_modify(|p| p.set_total(total)); self.update_overall(); } pub fn add_total(&mut self, total: u64) { self.progress.send_modify(|p| p.add_total(total)); self.update_overall(); } pub fn complete(&mut self) { self.progress.send_modify(|p| p.complete()); self.update_overall(); } } impl std::ops::AddAssign for PhaseProgressTrackerHandle { fn add_assign(&mut self, rhs: u64) { self.progress.send_modify(|p| *p += rhs); self.update_overall(); } } #[pin_project::pin_project] pub struct ProgressTrackerWriter { #[pin] writer: W, progress: PhaseProgressTrackerHandle, } impl ProgressTrackerWriter { pub fn new(writer: W, progress: PhaseProgressTrackerHandle) -> Self { Self { writer, progress } } pub fn into_inner(self) -> (W, PhaseProgressTrackerHandle) { (self.writer, self.progress) } } impl AsyncWrite for ProgressTrackerWriter { fn poll_write( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8], ) -> std::task::Poll> { let this = self.project(); match this.writer.poll_write(cx, buf) { std::task::Poll::Ready(Ok(n)) => { *this.progress += n as u64; std::task::Poll::Ready(Ok(n)) } a => a, } } fn poll_flush( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { self.project().writer.poll_flush(cx) } fn poll_shutdown( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { self.project().writer.poll_shutdown(cx) } fn is_write_vectored(&self) -> bool { self.writer.is_write_vectored() } fn poll_write_vectored( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, bufs: &[std::io::IoSlice<'_>], ) -> std::task::Poll> { self.project().writer.poll_write_vectored(cx, bufs) } } impl AsyncSeek for ProgressTrackerWriter { fn start_seek( self: std::pin::Pin<&mut Self>, position: std::io::SeekFrom, ) -> std::io::Result<()> { self.project().writer.start_seek(position) } fn poll_complete( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.project(); match this.writer.poll_complete(cx) { std::task::Poll::Ready(Ok(n)) => { this.progress.set_done(n); std::task::Poll::Ready(Ok(n)) } a => a, } } } pub struct PhasedProgressBar { multi: MultiProgress, overall: ProgressBar, phases: InOMap, } impl PhasedProgressBar { pub fn new(name: &str) -> Self { let multi = MultiProgress::new(); Self { overall: multi.add( ProgressBar::new(0) .with_style(SPINNER.clone()) .with_message(name.to_owned()), ), multi, phases: InOMap::new(), } } pub fn update(&mut self, progress: &FullProgress) { for phase in progress.phases.iter() { if !self.phases.contains_key(&phase.name) { self.phases.insert( phase.name.clone(), self.multi .add(ProgressBar::new(0).with_style(SPINNER.clone())) .with_message((&*phase.name).to_owned()), ); } } progress.overall.update_bar(&self.overall); for (name, bar) in self.phases.iter() { if let Some(progress) = progress.phases.iter().find_map(|p| { if &p.name == name { Some(p.progress) } else { None } }) { progress.update_bar(bar); } } } }