mirror of
https://github.com/Start9Labs/start-os.git
synced 2026-03-26 18:31:52 +00:00
Implements DirectIoFile for faster package uploads by bypassing page cache. Refactors RpcContinuations to support graceful WebSocket shutdown via broadcast signal, improving stability during daemon restart.
468 lines
15 KiB
Rust
468 lines
15 KiB
Rust
use std::io::SeekFrom;
|
|
use std::path::Path;
|
|
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
use std::task::Poll;
|
|
use std::time::Duration;
|
|
|
|
use axum::body::Body;
|
|
use axum::extract::Request;
|
|
use axum::response::Response;
|
|
use bytes::Bytes;
|
|
use futures::{FutureExt, Stream, StreamExt, ready};
|
|
use http::header::CONTENT_LENGTH;
|
|
use http::{HeaderMap, StatusCode};
|
|
use imbl_value::InternedString;
|
|
use tokio::io::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt};
|
|
use tokio::sync::watch;
|
|
|
|
use crate::context::RpcContext;
|
|
use crate::prelude::*;
|
|
use crate::progress::{PhaseProgressTrackerHandle, ProgressUnits};
|
|
use crate::rpc_continuations::{Guid, RpcContinuation};
|
|
use crate::s9pk::merkle_archive::source::ArchiveSource;
|
|
use crate::s9pk::merkle_archive::source::multi_cursor_file::{FileCursor, MultiCursorFile};
|
|
use crate::util::direct_io::DirectIoFile;
|
|
use crate::util::io::{TmpDir, create_file};
|
|
|
|
pub async fn upload(
|
|
ctx: &RpcContext,
|
|
session: Option<InternedString>,
|
|
progress: PhaseProgressTrackerHandle,
|
|
) -> Result<(Guid, UploadingFile), Error> {
|
|
let guid = Guid::new();
|
|
let (mut handle, file) = UploadingFile::new(progress).await?;
|
|
ctx.rpc_continuations
|
|
.add(
|
|
guid.clone(),
|
|
RpcContinuation::rest_authed(
|
|
ctx,
|
|
session,
|
|
|request| async move {
|
|
handle.upload(request).await;
|
|
|
|
Response::builder()
|
|
.status(StatusCode::NO_CONTENT)
|
|
.body(Body::empty())
|
|
.with_kind(ErrorKind::Network)
|
|
},
|
|
Duration::from_secs(30),
|
|
),
|
|
)
|
|
.await;
|
|
Ok((guid, file))
|
|
}
|
|
|
|
struct Progress {
|
|
tracker: PhaseProgressTrackerHandle,
|
|
expected_size: Option<u64>,
|
|
written: u64,
|
|
complete: bool,
|
|
error: Option<Error>,
|
|
}
|
|
impl Progress {
|
|
fn handle_error(&mut self, e: &std::io::Error) -> bool {
|
|
if self.error.is_none() {
|
|
self.error = Some(Error::new(eyre!("{e}"), ErrorKind::Network));
|
|
true
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
async fn expected_size(watch: &mut watch::Receiver<Self>) -> Option<u64> {
|
|
watch
|
|
.wait_for(|progress| progress.error.is_some() || progress.expected_size.is_some())
|
|
.await
|
|
.ok()
|
|
.and_then(|a| a.expected_size)
|
|
}
|
|
async fn ready_for(watch: &mut watch::Receiver<Self>, size: u64) -> Result<(), Error> {
|
|
match &*watch
|
|
.wait_for(|progress| {
|
|
progress.error.is_some()
|
|
|| progress.written >= size
|
|
|| progress.expected_size.map_or(false, |e| e < size)
|
|
})
|
|
.await
|
|
.map_err(|_| {
|
|
Error::new(
|
|
eyre!("failed to determine upload progress"),
|
|
ErrorKind::Network,
|
|
)
|
|
})? {
|
|
Progress { error: Some(e), .. } => Err(e.clone_output()),
|
|
Progress {
|
|
expected_size: Some(e),
|
|
..
|
|
} if *e < size => Err(Error::new(
|
|
eyre!("file size is less than requested"),
|
|
ErrorKind::Network,
|
|
)),
|
|
_ => Ok(()),
|
|
}
|
|
}
|
|
async fn ready(watch: &mut watch::Receiver<Self>) -> Result<(), Error> {
|
|
match &*watch
|
|
.wait_for(|progress| progress.error.is_some() || progress.complete)
|
|
.await
|
|
.map_err(|_| {
|
|
Error::new(
|
|
eyre!("failed to determine upload progress"),
|
|
ErrorKind::Network,
|
|
)
|
|
})? {
|
|
Progress { error: Some(e), .. } => Err(e.clone_output()),
|
|
_ => Ok(()),
|
|
}
|
|
}
|
|
fn complete(&mut self) -> bool {
|
|
let mut changed = !self.complete;
|
|
self.tracker.complete();
|
|
changed |= match self {
|
|
Self {
|
|
expected_size: Some(size),
|
|
written,
|
|
..
|
|
} if *written == *size => false,
|
|
Self {
|
|
expected_size: Some(size),
|
|
written,
|
|
error,
|
|
..
|
|
} if *written > *size && error.is_none() => {
|
|
*error = Some(Error::new(
|
|
eyre!("Too many bytes received"),
|
|
ErrorKind::Network,
|
|
));
|
|
true
|
|
}
|
|
Self {
|
|
error,
|
|
expected_size: Some(_),
|
|
..
|
|
} if error.is_none() => {
|
|
*error = Some(Error::new(
|
|
eyre!("Connection closed or timed out before full file received"),
|
|
ErrorKind::Network,
|
|
));
|
|
true
|
|
}
|
|
Self {
|
|
expected_size,
|
|
written,
|
|
..
|
|
} if expected_size.is_none() => {
|
|
*expected_size = Some(*written);
|
|
true
|
|
}
|
|
_ => false,
|
|
};
|
|
self.complete = true;
|
|
changed
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct UploadingFile {
|
|
tmp_dir: Option<Arc<TmpDir>>,
|
|
file: MultiCursorFile,
|
|
progress: watch::Receiver<Progress>,
|
|
}
|
|
impl UploadingFile {
|
|
pub async fn with_path(
|
|
path: impl AsRef<Path>,
|
|
mut progress: PhaseProgressTrackerHandle,
|
|
) -> Result<(UploadHandle, Self), Error> {
|
|
progress.set_units(Some(ProgressUnits::Bytes));
|
|
let progress = watch::channel(Progress {
|
|
tracker: progress,
|
|
expected_size: None,
|
|
written: 0,
|
|
error: None,
|
|
complete: false,
|
|
});
|
|
let file = create_file(path).await?;
|
|
let multi_cursor = MultiCursorFile::open(&file).await?;
|
|
let direct_file = DirectIoFile::from_tokio_file(file).await?;
|
|
let uploading = Self {
|
|
tmp_dir: None,
|
|
file: multi_cursor,
|
|
progress: progress.1,
|
|
};
|
|
Ok((
|
|
UploadHandle {
|
|
tmp_dir: None,
|
|
file: direct_file,
|
|
progress: progress.0,
|
|
last_synced: 0,
|
|
},
|
|
uploading,
|
|
))
|
|
}
|
|
pub async fn new(progress: PhaseProgressTrackerHandle) -> Result<(UploadHandle, Self), Error> {
|
|
let tmp_dir = Arc::new(TmpDir::new().await?);
|
|
let (mut handle, mut file) = Self::with_path(tmp_dir.join("upload.tmp"), progress).await?;
|
|
handle.tmp_dir = Some(tmp_dir.clone());
|
|
file.tmp_dir = Some(tmp_dir);
|
|
Ok((handle, file))
|
|
}
|
|
pub async fn wait_for_complete(&self) -> Result<(), Error> {
|
|
Progress::ready(&mut self.progress.clone()).await
|
|
}
|
|
pub async fn delete(self) -> Result<(), Error> {
|
|
if let Some(Ok(tmp_dir)) = self.tmp_dir.map(Arc::try_unwrap) {
|
|
tmp_dir.delete().await?;
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
impl ArchiveSource for UploadingFile {
|
|
type FetchReader = <MultiCursorFile as ArchiveSource>::FetchReader;
|
|
type FetchAllReader = UploadingFileReader;
|
|
async fn size(&self) -> Option<u64> {
|
|
Progress::expected_size(&mut self.progress.clone()).await
|
|
}
|
|
async fn fetch_all(&self) -> Result<Self::FetchAllReader, Error> {
|
|
let mut file = self.file.cursor().await?;
|
|
file.seek(SeekFrom::Start(0)).await?;
|
|
Ok(UploadingFileReader {
|
|
tmp_dir: self.tmp_dir.clone(),
|
|
file,
|
|
position: 0,
|
|
to_seek: None,
|
|
progress: self.progress.clone(),
|
|
})
|
|
}
|
|
async fn fetch(&self, position: u64, size: u64) -> Result<Self::FetchReader, Error> {
|
|
Progress::ready_for(&mut self.progress.clone(), position + size).await?;
|
|
self.file.fetch(position, size).await
|
|
}
|
|
}
|
|
|
|
#[pin_project::pin_project(project = UploadingFileReaderProjection)]
|
|
pub struct UploadingFileReader {
|
|
tmp_dir: Option<Arc<TmpDir>>,
|
|
position: u64,
|
|
to_seek: Option<SeekFrom>,
|
|
#[pin]
|
|
file: FileCursor,
|
|
progress: watch::Receiver<Progress>,
|
|
}
|
|
impl<'a> UploadingFileReaderProjection<'a> {
|
|
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Result<bool, std::io::Error> {
|
|
let ready = Progress::ready(&mut *self.progress);
|
|
tokio::pin!(ready);
|
|
Ok(ready
|
|
.poll_unpin(cx)
|
|
.map_err(|e| std::io::Error::other(e.source))?
|
|
.is_ready())
|
|
}
|
|
fn poll_ready_for(
|
|
&mut self,
|
|
cx: &mut std::task::Context<'_>,
|
|
size: u64,
|
|
) -> Result<bool, std::io::Error> {
|
|
let ready = Progress::ready_for(&mut *self.progress, size);
|
|
tokio::pin!(ready);
|
|
Ok(ready
|
|
.poll_unpin(cx)
|
|
.map_err(|e| std::io::Error::other(e.source))?
|
|
.is_ready())
|
|
}
|
|
}
|
|
impl AsyncRead for UploadingFileReader {
|
|
fn poll_read(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
buf: &mut tokio::io::ReadBuf<'_>,
|
|
) -> Poll<std::io::Result<()>> {
|
|
let mut this = self.project();
|
|
|
|
let position = *this.position;
|
|
if this.poll_ready(cx)? || this.poll_ready_for(cx, position + buf.remaining() as u64)? {
|
|
let start = buf.filled().len();
|
|
let res = this.file.poll_read(cx, buf);
|
|
*this.position += (buf.filled().len() - start) as u64;
|
|
res
|
|
} else {
|
|
Poll::Pending
|
|
}
|
|
}
|
|
}
|
|
impl AsyncSeek for UploadingFileReader {
|
|
fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> std::io::Result<()> {
|
|
let this = self.project();
|
|
*this.to_seek = Some(position);
|
|
Ok(())
|
|
}
|
|
fn poll_complete(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> Poll<std::io::Result<u64>> {
|
|
let mut this = self.project();
|
|
if let Some(to_seek) = *this.to_seek {
|
|
let size = match to_seek {
|
|
SeekFrom::Current(n) => (*this.position as i64 + n) as u64,
|
|
SeekFrom::Start(n) => n,
|
|
SeekFrom::End(n) => {
|
|
let expected_size = this.progress.borrow().expected_size;
|
|
match expected_size {
|
|
Some(end) => (end as i64 + n) as u64,
|
|
None => {
|
|
if !this.poll_ready(cx)? {
|
|
return Poll::Pending;
|
|
}
|
|
(this.progress.borrow().expected_size.ok_or_else(|| {
|
|
std::io::Error::new(
|
|
std::io::ErrorKind::Other,
|
|
eyre!("upload maked complete without expected size"),
|
|
)
|
|
})? as i64
|
|
+ n) as u64
|
|
}
|
|
}
|
|
}
|
|
};
|
|
if !this.poll_ready_for(cx, size)? {
|
|
return Poll::Pending;
|
|
}
|
|
}
|
|
if let Some(seek) = this.to_seek.take() {
|
|
this.file.as_mut().start_seek(seek)?;
|
|
}
|
|
*this.position = ready!(this.file.as_mut().poll_complete(cx)?);
|
|
Poll::Ready(Ok(*this.position))
|
|
}
|
|
}
|
|
|
|
#[pin_project::pin_project(PinnedDrop)]
|
|
pub struct UploadHandle {
|
|
tmp_dir: Option<Arc<TmpDir>>,
|
|
#[pin]
|
|
file: DirectIoFile,
|
|
progress: watch::Sender<Progress>,
|
|
last_synced: u64,
|
|
}
|
|
impl UploadHandle {
|
|
pub async fn upload(&mut self, request: Request) {
|
|
self.process_headers(request.headers());
|
|
self.process_body(request.into_body().into_data_stream())
|
|
.await;
|
|
self.progress.send_if_modified(|p| p.complete());
|
|
}
|
|
pub async fn download(&mut self, response: reqwest::Response) {
|
|
self.process_headers(response.headers());
|
|
self.process_body(response.bytes_stream()).await;
|
|
self.progress.send_if_modified(|p| p.complete());
|
|
}
|
|
fn process_headers(&mut self, headers: &HeaderMap) {
|
|
if let Some(content_length) = headers
|
|
.get(CONTENT_LENGTH)
|
|
.and_then(|a| a.to_str().log_err())
|
|
.and_then(|a| a.parse::<u64>().log_err())
|
|
{
|
|
self.progress.send_modify(|p| {
|
|
p.expected_size = Some(content_length);
|
|
p.tracker.set_total(content_length);
|
|
});
|
|
}
|
|
}
|
|
async fn process_body<E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>>(
|
|
&mut self,
|
|
mut body: impl Stream<Item = Result<Bytes, E>> + Unpin,
|
|
) {
|
|
while let Some(next) = body.next().await {
|
|
if let Err(e) = async {
|
|
self.write_all(
|
|
&next.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?,
|
|
)
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
.await
|
|
{
|
|
self.progress.send_if_modified(|p| p.handle_error(&e));
|
|
break;
|
|
}
|
|
}
|
|
if let Err(e) = self.file.sync_all().await {
|
|
self.progress.send_if_modified(|p| p.handle_error(&e));
|
|
}
|
|
// Update progress with final synced bytes
|
|
self.update_sync_progress();
|
|
}
|
|
fn update_sync_progress(&mut self) {
|
|
let synced = self.file.bytes_synced();
|
|
let delta = synced - self.last_synced;
|
|
if delta > 0 {
|
|
self.last_synced = synced;
|
|
self.progress.send_modify(|p| {
|
|
p.written += delta;
|
|
p.tracker += delta;
|
|
});
|
|
}
|
|
}
|
|
}
|
|
#[pin_project::pinned_drop]
|
|
impl PinnedDrop for UploadHandle {
|
|
fn drop(self: Pin<&mut Self>) {
|
|
let this = self.project();
|
|
this.progress.send_if_modified(|p| p.complete());
|
|
}
|
|
}
|
|
impl AsyncWrite for UploadHandle {
|
|
fn poll_write(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<Result<usize, std::io::Error>> {
|
|
let this = self.project();
|
|
// Update progress based on bytes actually flushed to disk
|
|
let synced = this.file.bytes_synced();
|
|
let delta = synced - *this.last_synced;
|
|
if delta > 0 {
|
|
*this.last_synced = synced;
|
|
this.progress.send_modify(|p| {
|
|
p.written += delta;
|
|
p.tracker += delta;
|
|
});
|
|
}
|
|
match this.file.poll_write(cx, buf) {
|
|
Poll::Ready(Err(e)) => {
|
|
this.progress
|
|
.send_if_modified(|progress| progress.handle_error(&e));
|
|
Poll::Ready(Err(e))
|
|
}
|
|
a => a,
|
|
}
|
|
}
|
|
fn poll_flush(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> Poll<Result<(), std::io::Error>> {
|
|
let this = self.project();
|
|
match this.file.poll_flush(cx) {
|
|
Poll::Ready(Err(e)) => {
|
|
this.progress
|
|
.send_if_modified(|progress| progress.handle_error(&e));
|
|
Poll::Ready(Err(e))
|
|
}
|
|
a => a,
|
|
}
|
|
}
|
|
fn poll_shutdown(
|
|
self: Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> Poll<Result<(), std::io::Error>> {
|
|
let this = self.project();
|
|
match this.file.poll_shutdown(cx) {
|
|
Poll::Ready(Err(e)) => {
|
|
this.progress
|
|
.send_if_modified(|progress| progress.handle_error(&e));
|
|
Poll::Ready(Err(e))
|
|
}
|
|
a => a,
|
|
}
|
|
}
|
|
}
|