[feat]: resumable downloads (#1746)

* optimize tests

* add test

* resumable downloads
This commit is contained in:
Aiden McClelland
2022-08-24 15:22:49 -06:00
committed by GitHub
parent 2dd31fa93f
commit b2d7f4f606
2 changed files with 139 additions and 79 deletions

View File

@@ -144,3 +144,6 @@ url = { version = "2.2.2", features = ["serde"] }
[profile.dev.package.backtrace] [profile.dev.package.backtrace]
opt-level = 3 opt-level = 3
[profile.test]
opt-level = 3

View File

@@ -7,12 +7,13 @@ use std::task::{Context, Poll};
use color_eyre::eyre::eyre; use color_eyre::eyre::eyre;
use futures::future::BoxFuture; use futures::future::BoxFuture;
use futures::FutureExt; use futures::stream::BoxStream;
use futures::{FutureExt, StreamExt};
use http::header::{ACCEPT_RANGES, CONTENT_LENGTH, RANGE}; use http::header::{ACCEPT_RANGES, CONTENT_LENGTH, RANGE};
use hyper::body::Bytes;
use pin_project::pin_project; use pin_project::pin_project;
use reqwest::{Client, Url}; use reqwest::{Client, Url};
use tokio::io::{AsyncRead, AsyncSeek}; use tokio::io::{AsyncRead, AsyncSeek};
use tracing::trace;
use crate::{Error, ResultExt}; use crate::{Error, ResultExt};
@@ -23,7 +24,20 @@ pub struct HttpReader {
http_client: Client, http_client: Client,
total_bytes: usize, total_bytes: usize,
range_unit: Option<RangeUnit>, range_unit: Option<RangeUnit>,
read_in_progress: Option<BoxFuture<'static, Result<Vec<u8>, Error>>>, read_in_progress: ReadInProgress,
}
/// Tracks the range request currently backing an `HttpReader`, so that a read
/// can be continued across successive polls.
enum ReadInProgress {
/// Nothing is pending. This is the value `take` leaves behind; the read path
/// reacts to it by issuing a new `HttpReader::get_range` request.
None,
/// A request has been started but its future has not resolved yet. The future
/// yields either the response body stream or an `Error`.
InProgress(
BoxFuture<'static, Result<BoxStream<'static, Result<Bytes, reqwest::Error>>, Error>>,
),
/// The request future resolved successfully; holds the response body as a
/// stream of `Bytes` chunks that the read path polls for data.
Complete(BoxStream<'static, Result<Bytes, reqwest::Error>>),
}
impl ReadInProgress {
/// Moves the current state out to the caller and leaves `Self::None` in its
/// place.
fn take(&mut self) -> Self {
    let mut previous = Self::None;
    // Exchange the fresh `None` with whatever state is stored right now.
    std::mem::swap(self, &mut previous);
    previous
}
} }
// If we want to add support for units other than Accept-Ranges: bytes, we can use this enum // If we want to add support for units other than Accept-Ranges: bytes, we can use this enum
@@ -31,6 +45,11 @@ pub struct HttpReader {
enum RangeUnit { enum RangeUnit {
Bytes, Bytes,
} }
impl Default for RangeUnit {
fn default() -> Self {
RangeUnit::Bytes
}
}
impl Display for RangeUnit { impl Display for RangeUnit {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -110,7 +129,7 @@ impl HttpReader {
http_client, http_client,
total_bytes, total_bytes,
range_unit, range_unit,
read_in_progress: None, read_in_progress: ReadInProgress::None,
}) })
} }
@@ -122,42 +141,25 @@ impl HttpReader {
start: usize, start: usize,
len: usize, len: usize,
total_bytes: usize, total_bytes: usize,
) -> Result<Vec<u8>, Error> { ) -> Result<BoxStream<'static, Result<Bytes, reqwest::Error>>, Error> {
let mut data = Vec::with_capacity(len);
let end = min(start + len, total_bytes) - 1; let end = min(start + len, total_bytes) - 1;
if start > end { if start > end {
return Ok(data); return Ok(futures::stream::empty().boxed());
} }
match range_unit { let data_range = format!("{}={}-{} ", range_unit.unwrap_or_default(), start, end);
Some(unit) => {
let data_range = format!("{}={}-{} ", unit, start, end);
trace!("get range alive? {}", data_range);
let data_resp = http_client let data_resp = http_client
.get(http_url) .get(http_url)
.header(RANGE, data_range) .header(RANGE, data_range)
.send() .send()
.await .await
.with_kind(crate::ErrorKind::InvalidRequest)?; .with_kind(crate::ErrorKind::Network)?
.error_for_status()
.with_kind(crate::ErrorKind::Network)?;
let status_code = data_resp.status(); Ok(data_resp.bytes_stream().boxed())
//let data_res = data_resp.bytes().await;
if status_code.is_success() {
data = data_resp
.bytes()
.await
.with_kind(crate::ErrorKind::BytesError)?
.to_vec();
}
}
None => unreachable!(),
}
Ok(data)
} }
} }
@@ -167,51 +169,88 @@ impl AsyncRead for HttpReader {
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>, buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> { ) -> Poll<std::io::Result<()>> {
let this = self.project(); fn poll_complete(
body: &mut BoxStream<'static, Result<Bytes, reqwest::Error>>,
let mut fut = if let Some(fut) = this.read_in_progress.take() { cx: &mut Context<'_>,
fut buf: &mut tokio::io::ReadBuf<'_>,
} else { ) -> Poll<Option<std::io::Result<usize>>> {
HttpReader::get_range( Poll::Ready(match futures::ready!(body.as_mut().poll_next(cx)) {
*this.range_unit, Some(Ok(bytes)) => {
this.http_client.clone(), if buf.remaining() < bytes.len() {
this.http_url.clone(), Some(Err(StdIOError::new(
*this.cursor_pos, std::io::ErrorKind::InvalidInput,
buf.remaining(), format!("more bytes returned than expected"),
*this.total_bytes, )))
)
.boxed()
};
let res_poll = fut.as_mut().poll(cx);
trace!("Polled with remaining bytes in buf: {}", buf.remaining());
match res_poll {
Poll::Ready(result) => match result {
Ok(data_chunk) => {
trace!("data chunk: len: {}", data_chunk.len());
trace!("buf filled len: {}", buf.filled().len());
if data_chunk.len() <= buf.remaining() {
buf.put_slice(&data_chunk);
*this.cursor_pos += data_chunk.len();
Poll::Ready(Ok(()))
} else { } else {
buf.put_slice(&data_chunk); buf.put_slice(&*bytes);
Some(Ok(bytes.len()))
Poll::Ready(Ok(()))
} }
} }
Err(err) => Poll::Ready(Err(StdIOError::new( Some(Err(e)) => Some(Err(StdIOError::new(std::io::ErrorKind::Interrupted, e))),
std::io::ErrorKind::Interrupted, None => None,
Box::<dyn std::error::Error + Send + Sync>::from(err.source), })
))), }
}, let this = self.project();
Poll::Pending => {
*this.read_in_progress = Some(fut);
Poll::Pending loop {
let mut in_progress = match this.read_in_progress.take() {
ReadInProgress::Complete(mut body) => match poll_complete(&mut body, cx, buf) {
Poll::Pending => {
*this.read_in_progress = ReadInProgress::Complete(body);
return Poll::Pending;
}
Poll::Ready(Some(Ok(len))) => {
*this.read_in_progress = ReadInProgress::Complete(body);
*this.cursor_pos += len;
return Poll::Ready(Ok(()));
}
Poll::Ready(res) => {
if let Some(Err(e)) = res {
tracing::error!(
"Error reading bytes from {}: {}, attempting to resume download",
this.http_url,
e
);
tracing::debug!("{:?}", e);
}
if *this.cursor_pos == *this.total_bytes {
return Poll::Ready(Ok(()));
}
continue;
}
},
ReadInProgress::None => HttpReader::get_range(
*this.range_unit,
this.http_client.clone(),
this.http_url.clone(),
*this.cursor_pos,
buf.remaining(),
*this.total_bytes,
)
.boxed(),
ReadInProgress::InProgress(fut) => fut,
};
let res_poll = in_progress.as_mut().poll(cx);
match res_poll {
Poll::Ready(result) => match result {
Ok(body) => {
*this.read_in_progress = ReadInProgress::Complete(body);
}
Err(err) => {
break Poll::Ready(Err(StdIOError::new(
std::io::ErrorKind::Interrupted,
Box::<dyn std::error::Error + Send + Sync>::from(err.source),
)));
}
},
Poll::Pending => {
*this.read_in_progress = ReadInProgress::InProgress(in_progress);
break Poll::Pending;
}
} }
} }
} }
@@ -221,6 +260,8 @@ impl AsyncSeek for HttpReader {
fn start_seek(self: Pin<&mut Self>, position: std::io::SeekFrom) -> std::io::Result<()> { fn start_seek(self: Pin<&mut Self>, position: std::io::SeekFrom) -> std::io::Result<()> {
let this = self.project(); let this = self.project();
this.read_in_progress.take(); // invalidate any existing reads
match position { match position {
std::io::SeekFrom::Start(offset) => { std::io::SeekFrom::Start(offset) => {
let pos_res = usize::try_from(offset); let pos_res = usize::try_from(offset);
@@ -285,16 +326,32 @@ impl AsyncSeek for HttpReader {
#[tokio::test] #[tokio::test]
async fn main_test() { async fn main_test() {
use tokio::io::AsyncReadExt;
let http_url = Url::parse("https://start9.com/latest/_static/css/main.css").unwrap(); let http_url = Url::parse("https://start9.com/latest/_static/css/main.css").unwrap();
println!("Getting this resource: {}", http_url); println!("Getting this resource: {}", http_url);
let mut test_reader = HttpReader::new(http_url).await.unwrap(); let mut test_reader = HttpReader::new(http_url).await.unwrap();
let mut buf = vec![0; test_reader.total_bytes]; let mut buf = Vec::new();
let bytes_read = test_reader.read(&mut buf).await.unwrap();
println!("bytes read: {}", bytes_read); tokio::io::copy(&mut test_reader, &mut buf).await.unwrap();
//println!("{}", String::from_utf8(buf).unwrap()); assert_eq!(buf.len(), test_reader.total_bytes)
}
#[tokio::test]
async fn s9pk_test() {
use tokio::io::BufReader;
let http_url = Url::parse("https://github.com/Start9Labs/hello-world-wrapper/releases/download/v0.3.0/hello-world.s9pk").unwrap();
println!("Getting this resource: {}", http_url);
let mut test_reader =
BufReader::with_capacity(1024 * 1024, HttpReader::new(http_url).await.unwrap());
let mut s9pk = crate::s9pk::reader::S9pkReader::from_reader(test_reader, true)
.await
.unwrap();
let manifest = s9pk.manifest().await.unwrap();
assert_eq!(&**manifest.id, "hello-world");
} }