diff --git a/core/models/src/errors.rs b/core/models/src/errors.rs index 8bbc705ee..ee6b0ae12 100644 --- a/core/models/src/errors.rs +++ b/core/models/src/errors.rs @@ -490,6 +490,7 @@ where { fn with_kind(self, kind: ErrorKind) -> Result; fn with_ctx (ErrorKind, D), D: Display>(self, f: F) -> Result; + fn log_err(self) -> Option; } impl ResultExt for Result where @@ -516,6 +517,18 @@ where } }) } + + fn log_err(self) -> Option { + match self { + Ok(a) => Some(a), + Err(e) => { + let e: color_eyre::eyre::Error = e.into(); + tracing::error!("{e}"); + tracing::debug!("{e:?}"); + None + } + } + } } impl ResultExt for Result { fn with_kind(self, kind: ErrorKind) -> Result { @@ -539,6 +552,17 @@ impl ResultExt for Result { } }) } + + fn log_err(self) -> Option { + match self { + Ok(a) => Some(a), + Err(e) => { + tracing::error!("{e}"); + tracing::debug!("{e:?}"); + None + } + } + } } pub trait OptionExt diff --git a/core/startos/src/install/mod.rs b/core/startos/src/install/mod.rs index 8d4f65312..7a545a3aa 100644 --- a/core/startos/src/install/mod.rs +++ b/core/startos/src/install/mod.rs @@ -157,7 +157,7 @@ pub async fn install( .services .install( ctx.clone(), - || asset.deserialize_s9pk(ctx.client.clone()), + || asset.deserialize_s9pk_buffered(ctx.client.clone()), None::, None, ) diff --git a/core/startos/src/registry/asset.rs b/core/startos/src/registry/asset.rs index 7697a0c99..fb6dd59fc 100644 --- a/core/startos/src/registry/asset.rs +++ b/core/startos/src/registry/asset.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::sync::Arc; use chrono::{DateTime, Utc}; +use helpers::NonDetachingJoinHandle; use reqwest::Client; use serde::{Deserialize, Serialize}; use tokio::io::AsyncWrite; @@ -14,8 +15,9 @@ use crate::registry::signer::commitment::{Commitment, Digestable}; use crate::registry::signer::sign::{AnySignature, AnyVerifyingKey}; use crate::registry::signer::AcceptSigners; use crate::s9pk::merkle_archive::source::http::HttpSource; -use crate::s9pk::merkle_archive::source::Section; +use crate::s9pk::merkle_archive::source::{ArchiveSource, Section}; use crate::s9pk::S9pk; +use crate::upload::UploadingFile; #[derive(Debug, Deserialize, Serialize, TS)] #[serde(rename_all = "camelCase")] @@ -70,4 +72,42 @@ impl RegistryAsset { ) .await } + pub async fn deserialize_s9pk_buffered( + &self, + client: Client, + ) -> Result>>, Error> { + S9pk::deserialize( + &Arc::new(BufferedHttpSource::new(client, self.url.clone()).await?), + Some(&self.commitment), + ) + .await + } +} + +pub struct BufferedHttpSource { + _download: NonDetachingJoinHandle<()>, + file: UploadingFile, +} +impl BufferedHttpSource { + pub async fn new(client: Client, url: Url) -> Result { + let (mut handle, file) = UploadingFile::new().await?; + let response = client.get(url).send().await?; + Ok(Self { + _download: tokio::spawn(async move { handle.download(response).await }).into(), + file, + }) + } +} +impl ArchiveSource for BufferedHttpSource { + type FetchReader = ::FetchReader; + type FetchAllReader = ::FetchAllReader; + async fn size(&self) -> Option { + self.file.size().await + } + async fn fetch_all(&self) -> Result { + self.file.fetch_all().await + } + async fn fetch(&self, position: u64, size: u64) -> Result { + self.file.fetch(position, size).await + } } diff --git a/core/startos/src/upload.rs b/core/startos/src/upload.rs index 4735a63fb..20f294400 100644 --- a/core/startos/src/upload.rs +++ b/core/startos/src/upload.rs @@ -5,10 +5,12 @@ use std::task::Poll; use std::time::Duration; use axum::body::Body; +use axum::extract::Request; use axum::response::Response; -use futures::{ready, FutureExt, StreamExt}; +use bytes::Bytes; +use futures::{ready, FutureExt, Stream, StreamExt}; use http::header::CONTENT_LENGTH; -use http::StatusCode; +use http::{HeaderMap, StatusCode}; use imbl_value::InternedString; use tokio::fs::File; use tokio::io::{AsyncRead, AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt}; @@ -34,51 +36,7 @@ pub async fn upload( ctx, session, |request| async move { - let headers = request.headers(); - let content_length = match headers.get(CONTENT_LENGTH).map(|a| a.to_str()) { - None => { - return Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::from("Content-Length is required")) - .with_kind(ErrorKind::Network) - } - Some(Err(_)) => { - return Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::from("Invalid Content-Length")) - .with_kind(ErrorKind::Network) - } - Some(Ok(a)) => match a.parse::() { - Err(_) => { - return Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::from("Invalid Content-Length")) - .with_kind(ErrorKind::Network) - } - Ok(a) => a, - }, - }; - - handle - .progress - .send_modify(|p| p.expected_size = Some(content_length)); - - let mut body = request.into_body().into_data_stream(); - while let Some(next) = body.next().await { - if let Err(e) = async { - handle - .write_all(&next.map_err(|e| { - std::io::Error::new(std::io::ErrorKind::Other, e) - })?) - .await?; - Ok(()) - } - .await - { - handle.progress.send_if_modified(|p| p.handle_error(&e)); - break; - } - } + handle.upload(request).await; Response::builder() .status(StatusCode::NO_CONTENT) @@ -364,6 +322,46 @@ pub struct UploadHandle { file: File, progress: watch::Sender, } +impl UploadHandle { + pub async fn upload(&mut self, request: Request) { + self.process_headers(request.headers()); + self.process_body(request.into_body().into_data_stream()) + .await; + } + pub async fn download(&mut self, response: reqwest::Response) { + self.process_headers(response.headers()); + self.process_body(response.bytes_stream()).await; + } + fn process_headers(&mut self, headers: &HeaderMap) { + if let Some(content_length) = headers + .get(CONTENT_LENGTH) + .and_then(|a| a.to_str().log_err()) + .and_then(|a| a.parse::().log_err()) + { + self.progress + .send_modify(|p| p.expected_size = Some(content_length)); + } + } + async fn process_body>>( + &mut self, + mut body: impl Stream> + Unpin, + ) { + while let Some(next) = body.next().await { + if let Err(e) = async { + self.write_all( + &next.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?, + ) + .await?; + Ok(()) + } + .await + { + self.progress.send_if_modified(|p| p.handle_error(&e)); + break; + } + } + } +} #[pin_project::pinned_drop] impl PinnedDrop for UploadHandle { fn drop(self: Pin<&mut Self>) {