use std::io::SeekFrom; use std::path::Path; use std::pin::Pin; use std::sync::Arc; use std::task::Poll; use std::time::Duration; use axum::body::Body; use axum::extract::Request; use axum::response::Response; use bytes::Bytes; use futures::{FutureExt, Stream, StreamExt, ready}; use http::header::CONTENT_LENGTH; use http::{HeaderMap, StatusCode}; use imbl_value::InternedString; use tokio::io::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::watch; use crate::context::RpcContext; use crate::prelude::*; use crate::progress::{PhaseProgressTrackerHandle, ProgressUnits}; use crate::rpc_continuations::{Guid, RpcContinuation}; use crate::s9pk::merkle_archive::source::ArchiveSource; use crate::s9pk::merkle_archive::source::multi_cursor_file::{FileCursor, MultiCursorFile}; use crate::util::direct_io::DirectIoFile; use crate::util::io::{TmpDir, create_file}; pub async fn upload( ctx: &RpcContext, session: Option, progress: PhaseProgressTrackerHandle, ) -> Result<(Guid, UploadingFile), Error> { let guid = Guid::new(); let (mut handle, file) = UploadingFile::new(progress).await?; ctx.rpc_continuations .add( guid.clone(), RpcContinuation::rest_authed( ctx, session, |request| async move { handle.upload(request).await; Response::builder() .status(StatusCode::NO_CONTENT) .body(Body::empty()) .with_kind(ErrorKind::Network) }, Duration::from_secs(30), ), ) .await; Ok((guid, file)) } struct Progress { tracker: PhaseProgressTrackerHandle, expected_size: Option, written: u64, complete: bool, error: Option, } impl Progress { fn handle_error(&mut self, e: &std::io::Error) -> bool { if self.error.is_none() { self.error = Some(Error::new(eyre!("{e}"), ErrorKind::Network)); true } else { false } } async fn expected_size(watch: &mut watch::Receiver) -> Option { watch .wait_for(|progress| progress.error.is_some() || progress.expected_size.is_some()) .await .ok() .and_then(|a| a.expected_size) } async fn ready_for(watch: &mut watch::Receiver, size: u64) -> Result<(), Error> { match &*watch .wait_for(|progress| { progress.error.is_some() || progress.written >= size || progress.expected_size.map_or(false, |e| e < size) }) .await .map_err(|_| { Error::new( eyre!("failed to determine upload progress"), ErrorKind::Network, ) })? { Progress { error: Some(e), .. } => Err(e.clone_output()), Progress { expected_size: Some(e), .. } if *e < size => Err(Error::new( eyre!("file size is less than requested"), ErrorKind::Network, )), _ => Ok(()), } } async fn ready(watch: &mut watch::Receiver) -> Result<(), Error> { match &*watch .wait_for(|progress| progress.error.is_some() || progress.complete) .await .map_err(|_| { Error::new( eyre!("failed to determine upload progress"), ErrorKind::Network, ) })? { Progress { error: Some(e), .. } => Err(e.clone_output()), _ => Ok(()), } } fn complete(&mut self) -> bool { let mut changed = !self.complete; self.tracker.complete(); changed |= match self { Self { expected_size: Some(size), written, .. } if *written == *size => false, Self { expected_size: Some(size), written, error, .. } if *written > *size && error.is_none() => { *error = Some(Error::new( eyre!("Too many bytes received"), ErrorKind::Network, )); true } Self { error, expected_size: Some(_), .. } if error.is_none() => { *error = Some(Error::new( eyre!("Connection closed or timed out before full file received"), ErrorKind::Network, )); true } Self { expected_size, written, .. } if expected_size.is_none() => { *expected_size = Some(*written); true } _ => false, }; self.complete = true; changed } } #[derive(Clone)] pub struct UploadingFile { tmp_dir: Option>, file: MultiCursorFile, progress: watch::Receiver, } impl UploadingFile { pub async fn with_path( path: impl AsRef, mut progress: PhaseProgressTrackerHandle, ) -> Result<(UploadHandle, Self), Error> { progress.set_units(Some(ProgressUnits::Bytes)); let progress = watch::channel(Progress { tracker: progress, expected_size: None, written: 0, error: None, complete: false, }); let file = create_file(path).await?; let multi_cursor = MultiCursorFile::open(&file).await?; let direct_file = DirectIoFile::from_tokio_file(file).await?; let uploading = Self { tmp_dir: None, file: multi_cursor, progress: progress.1, }; Ok(( UploadHandle { tmp_dir: None, file: direct_file, progress: progress.0, last_synced: 0, }, uploading, )) } pub async fn new(progress: PhaseProgressTrackerHandle) -> Result<(UploadHandle, Self), Error> { let tmp_dir = Arc::new(TmpDir::new().await?); let (mut handle, mut file) = Self::with_path(tmp_dir.join("upload.tmp"), progress).await?; handle.tmp_dir = Some(tmp_dir.clone()); file.tmp_dir = Some(tmp_dir); Ok((handle, file)) } pub async fn wait_for_complete(&self) -> Result<(), Error> { Progress::ready(&mut self.progress.clone()).await } pub async fn delete(self) -> Result<(), Error> { if let Some(Ok(tmp_dir)) = self.tmp_dir.map(Arc::try_unwrap) { tmp_dir.delete().await?; } Ok(()) } } impl ArchiveSource for UploadingFile { type FetchReader = ::FetchReader; type FetchAllReader = UploadingFileReader; async fn size(&self) -> Option { Progress::expected_size(&mut self.progress.clone()).await } async fn fetch_all(&self) -> Result { let mut file = self.file.cursor().await?; file.seek(SeekFrom::Start(0)).await?; Ok(UploadingFileReader { tmp_dir: self.tmp_dir.clone(), file, position: 0, to_seek: None, progress: self.progress.clone(), }) } async fn fetch(&self, position: u64, size: u64) -> Result { Progress::ready_for(&mut self.progress.clone(), position + size).await?; self.file.fetch(position, size).await } } #[pin_project::pin_project(project = UploadingFileReaderProjection)] pub struct UploadingFileReader { tmp_dir: Option>, position: u64, to_seek: Option, #[pin] file: FileCursor, progress: watch::Receiver, } impl<'a> UploadingFileReaderProjection<'a> { fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Result { let ready = Progress::ready(&mut *self.progress); tokio::pin!(ready); Ok(ready .poll_unpin(cx) .map_err(|e| std::io::Error::other(e.source))? .is_ready()) } fn poll_ready_for( &mut self, cx: &mut std::task::Context<'_>, size: u64, ) -> Result { let ready = Progress::ready_for(&mut *self.progress, size); tokio::pin!(ready); Ok(ready .poll_unpin(cx) .map_err(|e| std::io::Error::other(e.source))? .is_ready()) } } impl AsyncRead for UploadingFileReader { fn poll_read( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { let mut this = self.project(); let position = *this.position; if this.poll_ready(cx)? || this.poll_ready_for(cx, position + buf.remaining() as u64)? { let start = buf.filled().len(); let res = this.file.poll_read(cx, buf); *this.position += (buf.filled().len() - start) as u64; res } else { Poll::Pending } } } impl AsyncSeek for UploadingFileReader { fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> std::io::Result<()> { let this = self.project(); *this.to_seek = Some(position); Ok(()) } fn poll_complete( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { let mut this = self.project(); if let Some(to_seek) = *this.to_seek { let size = match to_seek { SeekFrom::Current(n) => (*this.position as i64 + n) as u64, SeekFrom::Start(n) => n, SeekFrom::End(n) => { let expected_size = this.progress.borrow().expected_size; match expected_size { Some(end) => (end as i64 + n) as u64, None => { if !this.poll_ready(cx)? { return Poll::Pending; } (this.progress.borrow().expected_size.ok_or_else(|| { std::io::Error::new( std::io::ErrorKind::Other, eyre!("upload maked complete without expected size"), ) })? as i64 + n) as u64 } } } }; if !this.poll_ready_for(cx, size)? { return Poll::Pending; } } if let Some(seek) = this.to_seek.take() { this.file.as_mut().start_seek(seek)?; } *this.position = ready!(this.file.as_mut().poll_complete(cx)?); Poll::Ready(Ok(*this.position)) } } #[pin_project::pin_project(PinnedDrop)] pub struct UploadHandle { tmp_dir: Option>, #[pin] file: DirectIoFile, progress: watch::Sender, last_synced: u64, } impl UploadHandle { pub async fn upload(&mut self, request: Request) { self.process_headers(request.headers()); self.process_body(request.into_body().into_data_stream()) .await; self.progress.send_if_modified(|p| p.complete()); } pub async fn download(&mut self, response: reqwest::Response) { self.process_headers(response.headers()); self.process_body(response.bytes_stream()).await; self.progress.send_if_modified(|p| p.complete()); } fn process_headers(&mut self, headers: &HeaderMap) { if let Some(content_length) = headers .get(CONTENT_LENGTH) .and_then(|a| a.to_str().log_err()) .and_then(|a| a.parse::().log_err()) { self.progress.send_modify(|p| { p.expected_size = Some(content_length); p.tracker.set_total(content_length); }); } } async fn process_body>>( &mut self, mut body: impl Stream> + Unpin, ) { while let Some(next) = body.next().await { if let Err(e) = async { self.write_all( &next.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?, ) .await?; Ok(()) } .await { self.progress.send_if_modified(|p| p.handle_error(&e)); break; } } if let Err(e) = self.file.sync_all().await { self.progress.send_if_modified(|p| p.handle_error(&e)); } // Update progress with final synced bytes self.update_sync_progress(); } fn update_sync_progress(&mut self) { let synced = self.file.bytes_synced(); let delta = synced - self.last_synced; if delta > 0 { self.last_synced = synced; self.progress.send_modify(|p| { p.written += delta; p.tracker += delta; }); } } } #[pin_project::pinned_drop] impl PinnedDrop for UploadHandle { fn drop(self: Pin<&mut Self>) { let this = self.project(); this.progress.send_if_modified(|p| p.complete()); } } impl AsyncWrite for UploadHandle { fn poll_write( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8], ) -> Poll> { let this = self.project(); // Update progress based on bytes actually flushed to disk let synced = this.file.bytes_synced(); let delta = synced - *this.last_synced; if delta > 0 { *this.last_synced = synced; this.progress.send_modify(|p| { p.written += delta; p.tracker += delta; }); } match this.file.poll_write(cx, buf) { Poll::Ready(Err(e)) => { this.progress .send_if_modified(|progress| progress.handle_error(&e)); Poll::Ready(Err(e)) } a => a, } } fn poll_flush( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { let this = self.project(); match this.file.poll_flush(cx) { Poll::Ready(Err(e)) => { this.progress .send_if_modified(|progress| progress.handle_error(&e)); Poll::Ready(Err(e)) } a => a, } } fn poll_shutdown( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { let this = self.project(); match this.file.poll_shutdown(cx) { Poll::Ready(Err(e)) => { this.progress .send_if_modified(|progress| progress.handle_error(&e)); Poll::Ready(Err(e)) } a => a, } } }