diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 18ed8b17e..ca453ceba 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -144,3 +144,6 @@ url = { version = "2.2.2", features = ["serde"] } [profile.dev.package.backtrace] opt-level = 3 + +[profile.test] +opt-level = 3 diff --git a/backend/src/util/http_reader.rs b/backend/src/util/http_reader.rs index 2ed64a9c1..3bc4d4cd9 100644 --- a/backend/src/util/http_reader.rs +++ b/backend/src/util/http_reader.rs @@ -7,12 +7,13 @@ use std::task::{Context, Poll}; use color_eyre::eyre::eyre; use futures::future::BoxFuture; -use futures::FutureExt; +use futures::stream::BoxStream; +use futures::{FutureExt, StreamExt}; use http::header::{ACCEPT_RANGES, CONTENT_LENGTH, RANGE}; +use hyper::body::Bytes; use pin_project::pin_project; use reqwest::{Client, Url}; use tokio::io::{AsyncRead, AsyncSeek}; -use tracing::trace; use crate::{Error, ResultExt}; @@ -23,7 +24,20 @@ pub struct HttpReader { http_client: Client, total_bytes: usize, range_unit: Option, - read_in_progress: Option, Error>>>, + read_in_progress: ReadInProgress, +} + +enum ReadInProgress { + None, + InProgress( + BoxFuture<'static, Result>, Error>>, + ), + Complete(BoxStream<'static, Result>), +} +impl ReadInProgress { + fn take(&mut self) -> Self { + std::mem::replace(self, Self::None) + } } // If we want to add support for units other than Accept-Ranges: bytes, we can use this enum @@ -31,6 +45,11 @@ pub struct HttpReader { enum RangeUnit { Bytes, } +impl Default for RangeUnit { + fn default() -> Self { + RangeUnit::Bytes + } +} impl Display for RangeUnit { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -110,7 +129,7 @@ impl HttpReader { http_client, total_bytes, range_unit, - read_in_progress: None, + read_in_progress: ReadInProgress::None, }) } @@ -122,42 +141,25 @@ impl HttpReader { start: usize, len: usize, total_bytes: usize, - ) -> Result, Error> { - let mut data = Vec::with_capacity(len); - + ) -> Result>, Error> { let end = min(start + len, total_bytes) - 1; if start > end { - return Ok(data); + return Ok(futures::stream::empty().boxed()); } - match range_unit { - Some(unit) => { - let data_range = format!("{}={}-{} ", unit, start, end); - trace!("get range alive? {}", data_range); + let data_range = format!("{}={}-{} ", range_unit.unwrap_or_default(), start, end); - let data_resp = http_client - .get(http_url) - .header(RANGE, data_range) - .send() - .await - .with_kind(crate::ErrorKind::InvalidRequest)?; + let data_resp = http_client + .get(http_url) + .header(RANGE, data_range) + .send() + .await + .with_kind(crate::ErrorKind::Network)? + .error_for_status() + .with_kind(crate::ErrorKind::Network)?; - let status_code = data_resp.status(); - //let data_res = data_resp.bytes().await; - if status_code.is_success() { - data = data_resp - .bytes() - .await - .with_kind(crate::ErrorKind::BytesError)? - .to_vec(); - } - } - - None => unreachable!(), - } - - Ok(data) + Ok(data_resp.bytes_stream().boxed()) } } @@ -167,51 +169,88 @@ impl AsyncRead for HttpReader { cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { - let this = self.project(); - - let mut fut = if let Some(fut) = this.read_in_progress.take() { - fut - } else { - HttpReader::get_range( - *this.range_unit, - this.http_client.clone(), - this.http_url.clone(), - *this.cursor_pos, - buf.remaining(), - *this.total_bytes, - ) - .boxed() - }; - - let res_poll = fut.as_mut().poll(cx); - trace!("Polled with remaining bytes in buf: {}", buf.remaining()); - - match res_poll { - Poll::Ready(result) => match result { - Ok(data_chunk) => { - trace!("data chunk: len: {}", data_chunk.len()); - trace!("buf filled len: {}", buf.filled().len()); - - if data_chunk.len() <= buf.remaining() { - buf.put_slice(&data_chunk); - *this.cursor_pos += data_chunk.len(); - - Poll::Ready(Ok(())) + fn poll_complete( + body: &mut BoxStream<'static, Result>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll>> { + Poll::Ready(match futures::ready!(body.as_mut().poll_next(cx)) { + Some(Ok(bytes)) => { + if buf.remaining() < bytes.len() { + Some(Err(StdIOError::new( + std::io::ErrorKind::InvalidInput, + format!("more bytes returned than expected"), + ))) } else { - buf.put_slice(&data_chunk); - - Poll::Ready(Ok(())) + buf.put_slice(&*bytes); + Some(Ok(bytes.len())) } } - Err(err) => Poll::Ready(Err(StdIOError::new( - std::io::ErrorKind::Interrupted, - Box::::from(err.source), - ))), - }, - Poll::Pending => { - *this.read_in_progress = Some(fut); + Some(Err(e)) => Some(Err(StdIOError::new(std::io::ErrorKind::Interrupted, e))), + None => None, + }) + } + let this = self.project(); - Poll::Pending + loop { + let mut in_progress = match this.read_in_progress.take() { + ReadInProgress::Complete(mut body) => match poll_complete(&mut body, cx, buf) { + Poll::Pending => { + *this.read_in_progress = ReadInProgress::Complete(body); + return Poll::Pending; + } + Poll::Ready(Some(Ok(len))) => { + *this.read_in_progress = ReadInProgress::Complete(body); + *this.cursor_pos += len; + + return Poll::Ready(Ok(())); + } + Poll::Ready(res) => { + if let Some(Err(e)) = res { + tracing::error!( + "Error reading bytes from {}: {}, attempting to resume download", + this.http_url, + e + ); + tracing::debug!("{:?}", e); + } + if *this.cursor_pos == *this.total_bytes { + return Poll::Ready(Ok(())); + } + continue; + } + }, + ReadInProgress::None => HttpReader::get_range( + *this.range_unit, + this.http_client.clone(), + this.http_url.clone(), + *this.cursor_pos, + buf.remaining(), + *this.total_bytes, + ) + .boxed(), + ReadInProgress::InProgress(fut) => fut, + }; + + let res_poll = in_progress.as_mut().poll(cx); + + match res_poll { + Poll::Ready(result) => match result { + Ok(body) => { + *this.read_in_progress = ReadInProgress::Complete(body); + } + Err(err) => { + break Poll::Ready(Err(StdIOError::new( + std::io::ErrorKind::Interrupted, + Box::::from(err.source), + ))); + } + }, + Poll::Pending => { + *this.read_in_progress = ReadInProgress::InProgress(in_progress); + + break Poll::Pending; + } } } } @@ -221,6 +260,8 @@ impl AsyncSeek for HttpReader { fn start_seek(self: Pin<&mut Self>, position: std::io::SeekFrom) -> std::io::Result<()> { let this = self.project(); + this.read_in_progress.take(); // invalidate any existing reads + match position { std::io::SeekFrom::Start(offset) => { let pos_res = usize::try_from(offset); @@ -285,16 +326,32 @@ impl AsyncSeek for HttpReader { #[tokio::test] async fn main_test() { - use tokio::io::AsyncReadExt; let http_url = Url::parse("https://start9.com/latest/_static/css/main.css").unwrap(); println!("Getting this resource: {}", http_url); let mut test_reader = HttpReader::new(http_url).await.unwrap(); - let mut buf = vec![0; test_reader.total_bytes]; - let bytes_read = test_reader.read(&mut buf).await.unwrap(); + let mut buf = Vec::new(); - println!("bytes read: {}", bytes_read); + tokio::io::copy(&mut test_reader, &mut buf).await.unwrap(); - //println!("{}", String::from_utf8(buf).unwrap()); + assert_eq!(buf.len(), test_reader.total_bytes) +} + +#[tokio::test] +async fn s9pk_test() { + use tokio::io::BufReader; + + let http_url = Url::parse("https://github.com/Start9Labs/hello-world-wrapper/releases/download/v0.3.0/hello-world.s9pk").unwrap(); + + println!("Getting this resource: {}", http_url); + let mut test_reader = + BufReader::with_capacity(1024 * 1024, HttpReader::new(http_url).await.unwrap()); + + let mut s9pk = crate::s9pk::reader::S9pkReader::from_reader(test_reader, true) + .await + .unwrap(); + + let manifest = s9pk.manifest().await.unwrap(); + assert_eq!(&**manifest.id, "hello-world"); }