mirror of
https://github.com/Start9Labs/start-os.git
synced 2026-03-26 02:11:53 +00:00
[feat]: resumable downloads (#1746)
* optimize tests * add test * resumable downloads
This commit is contained in:
@@ -144,3 +144,6 @@ url = { version = "2.2.2", features = ["serde"] }
|
|||||||
|
|
||||||
[profile.dev.package.backtrace]
|
[profile.dev.package.backtrace]
|
||||||
opt-level = 3
|
opt-level = 3
|
||||||
|
|
||||||
|
[profile.test]
|
||||||
|
opt-level = 3
|
||||||
|
|||||||
@@ -7,12 +7,13 @@ use std::task::{Context, Poll};
|
|||||||
|
|
||||||
use color_eyre::eyre::eyre;
|
use color_eyre::eyre::eyre;
|
||||||
use futures::future::BoxFuture;
|
use futures::future::BoxFuture;
|
||||||
use futures::FutureExt;
|
use futures::stream::BoxStream;
|
||||||
|
use futures::{FutureExt, StreamExt};
|
||||||
use http::header::{ACCEPT_RANGES, CONTENT_LENGTH, RANGE};
|
use http::header::{ACCEPT_RANGES, CONTENT_LENGTH, RANGE};
|
||||||
|
use hyper::body::Bytes;
|
||||||
use pin_project::pin_project;
|
use pin_project::pin_project;
|
||||||
use reqwest::{Client, Url};
|
use reqwest::{Client, Url};
|
||||||
use tokio::io::{AsyncRead, AsyncSeek};
|
use tokio::io::{AsyncRead, AsyncSeek};
|
||||||
use tracing::trace;
|
|
||||||
|
|
||||||
use crate::{Error, ResultExt};
|
use crate::{Error, ResultExt};
|
||||||
|
|
||||||
@@ -23,7 +24,20 @@ pub struct HttpReader {
|
|||||||
http_client: Client,
|
http_client: Client,
|
||||||
total_bytes: usize,
|
total_bytes: usize,
|
||||||
range_unit: Option<RangeUnit>,
|
range_unit: Option<RangeUnit>,
|
||||||
read_in_progress: Option<BoxFuture<'static, Result<Vec<u8>, Error>>>,
|
read_in_progress: ReadInProgress,
|
||||||
|
}
|
||||||
|
|
||||||
|
enum ReadInProgress {
|
||||||
|
None,
|
||||||
|
InProgress(
|
||||||
|
BoxFuture<'static, Result<BoxStream<'static, Result<Bytes, reqwest::Error>>, Error>>,
|
||||||
|
),
|
||||||
|
Complete(BoxStream<'static, Result<Bytes, reqwest::Error>>),
|
||||||
|
}
|
||||||
|
impl ReadInProgress {
|
||||||
|
fn take(&mut self) -> Self {
|
||||||
|
std::mem::replace(self, Self::None)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we want to add support for units other than Accept-Ranges: bytes, we can use this enum
|
// If we want to add support for units other than Accept-Ranges: bytes, we can use this enum
|
||||||
@@ -31,6 +45,11 @@ pub struct HttpReader {
|
|||||||
enum RangeUnit {
|
enum RangeUnit {
|
||||||
Bytes,
|
Bytes,
|
||||||
}
|
}
|
||||||
|
impl Default for RangeUnit {
|
||||||
|
fn default() -> Self {
|
||||||
|
RangeUnit::Bytes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Display for RangeUnit {
|
impl Display for RangeUnit {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
@@ -110,7 +129,7 @@ impl HttpReader {
|
|||||||
http_client,
|
http_client,
|
||||||
total_bytes,
|
total_bytes,
|
||||||
range_unit,
|
range_unit,
|
||||||
read_in_progress: None,
|
read_in_progress: ReadInProgress::None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -122,42 +141,25 @@ impl HttpReader {
|
|||||||
start: usize,
|
start: usize,
|
||||||
len: usize,
|
len: usize,
|
||||||
total_bytes: usize,
|
total_bytes: usize,
|
||||||
) -> Result<Vec<u8>, Error> {
|
) -> Result<BoxStream<'static, Result<Bytes, reqwest::Error>>, Error> {
|
||||||
let mut data = Vec::with_capacity(len);
|
|
||||||
|
|
||||||
let end = min(start + len, total_bytes) - 1;
|
let end = min(start + len, total_bytes) - 1;
|
||||||
|
|
||||||
if start > end {
|
if start > end {
|
||||||
return Ok(data);
|
return Ok(futures::stream::empty().boxed());
|
||||||
}
|
}
|
||||||
|
|
||||||
match range_unit {
|
let data_range = format!("{}={}-{} ", range_unit.unwrap_or_default(), start, end);
|
||||||
Some(unit) => {
|
|
||||||
let data_range = format!("{}={}-{} ", unit, start, end);
|
|
||||||
trace!("get range alive? {}", data_range);
|
|
||||||
|
|
||||||
let data_resp = http_client
|
let data_resp = http_client
|
||||||
.get(http_url)
|
.get(http_url)
|
||||||
.header(RANGE, data_range)
|
.header(RANGE, data_range)
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.with_kind(crate::ErrorKind::InvalidRequest)?;
|
.with_kind(crate::ErrorKind::Network)?
|
||||||
|
.error_for_status()
|
||||||
|
.with_kind(crate::ErrorKind::Network)?;
|
||||||
|
|
||||||
let status_code = data_resp.status();
|
Ok(data_resp.bytes_stream().boxed())
|
||||||
//let data_res = data_resp.bytes().await;
|
|
||||||
if status_code.is_success() {
|
|
||||||
data = data_resp
|
|
||||||
.bytes()
|
|
||||||
.await
|
|
||||||
.with_kind(crate::ErrorKind::BytesError)?
|
|
||||||
.to_vec();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
None => unreachable!(),
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(data)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,51 +169,88 @@ impl AsyncRead for HttpReader {
|
|||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
buf: &mut tokio::io::ReadBuf<'_>,
|
buf: &mut tokio::io::ReadBuf<'_>,
|
||||||
) -> Poll<std::io::Result<()>> {
|
) -> Poll<std::io::Result<()>> {
|
||||||
let this = self.project();
|
fn poll_complete(
|
||||||
|
body: &mut BoxStream<'static, Result<Bytes, reqwest::Error>>,
|
||||||
let mut fut = if let Some(fut) = this.read_in_progress.take() {
|
cx: &mut Context<'_>,
|
||||||
fut
|
buf: &mut tokio::io::ReadBuf<'_>,
|
||||||
} else {
|
) -> Poll<Option<std::io::Result<usize>>> {
|
||||||
HttpReader::get_range(
|
Poll::Ready(match futures::ready!(body.as_mut().poll_next(cx)) {
|
||||||
*this.range_unit,
|
Some(Ok(bytes)) => {
|
||||||
this.http_client.clone(),
|
if buf.remaining() < bytes.len() {
|
||||||
this.http_url.clone(),
|
Some(Err(StdIOError::new(
|
||||||
*this.cursor_pos,
|
std::io::ErrorKind::InvalidInput,
|
||||||
buf.remaining(),
|
format!("more bytes returned than expected"),
|
||||||
*this.total_bytes,
|
)))
|
||||||
)
|
|
||||||
.boxed()
|
|
||||||
};
|
|
||||||
|
|
||||||
let res_poll = fut.as_mut().poll(cx);
|
|
||||||
trace!("Polled with remaining bytes in buf: {}", buf.remaining());
|
|
||||||
|
|
||||||
match res_poll {
|
|
||||||
Poll::Ready(result) => match result {
|
|
||||||
Ok(data_chunk) => {
|
|
||||||
trace!("data chunk: len: {}", data_chunk.len());
|
|
||||||
trace!("buf filled len: {}", buf.filled().len());
|
|
||||||
|
|
||||||
if data_chunk.len() <= buf.remaining() {
|
|
||||||
buf.put_slice(&data_chunk);
|
|
||||||
*this.cursor_pos += data_chunk.len();
|
|
||||||
|
|
||||||
Poll::Ready(Ok(()))
|
|
||||||
} else {
|
} else {
|
||||||
buf.put_slice(&data_chunk);
|
buf.put_slice(&*bytes);
|
||||||
|
Some(Ok(bytes.len()))
|
||||||
Poll::Ready(Ok(()))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(err) => Poll::Ready(Err(StdIOError::new(
|
Some(Err(e)) => Some(Err(StdIOError::new(std::io::ErrorKind::Interrupted, e))),
|
||||||
std::io::ErrorKind::Interrupted,
|
None => None,
|
||||||
Box::<dyn std::error::Error + Send + Sync>::from(err.source),
|
})
|
||||||
))),
|
}
|
||||||
},
|
let this = self.project();
|
||||||
Poll::Pending => {
|
|
||||||
*this.read_in_progress = Some(fut);
|
|
||||||
|
|
||||||
Poll::Pending
|
loop {
|
||||||
|
let mut in_progress = match this.read_in_progress.take() {
|
||||||
|
ReadInProgress::Complete(mut body) => match poll_complete(&mut body, cx, buf) {
|
||||||
|
Poll::Pending => {
|
||||||
|
*this.read_in_progress = ReadInProgress::Complete(body);
|
||||||
|
return Poll::Pending;
|
||||||
|
}
|
||||||
|
Poll::Ready(Some(Ok(len))) => {
|
||||||
|
*this.read_in_progress = ReadInProgress::Complete(body);
|
||||||
|
*this.cursor_pos += len;
|
||||||
|
|
||||||
|
return Poll::Ready(Ok(()));
|
||||||
|
}
|
||||||
|
Poll::Ready(res) => {
|
||||||
|
if let Some(Err(e)) = res {
|
||||||
|
tracing::error!(
|
||||||
|
"Error reading bytes from {}: {}, attempting to resume download",
|
||||||
|
this.http_url,
|
||||||
|
e
|
||||||
|
);
|
||||||
|
tracing::debug!("{:?}", e);
|
||||||
|
}
|
||||||
|
if *this.cursor_pos == *this.total_bytes {
|
||||||
|
return Poll::Ready(Ok(()));
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
ReadInProgress::None => HttpReader::get_range(
|
||||||
|
*this.range_unit,
|
||||||
|
this.http_client.clone(),
|
||||||
|
this.http_url.clone(),
|
||||||
|
*this.cursor_pos,
|
||||||
|
buf.remaining(),
|
||||||
|
*this.total_bytes,
|
||||||
|
)
|
||||||
|
.boxed(),
|
||||||
|
ReadInProgress::InProgress(fut) => fut,
|
||||||
|
};
|
||||||
|
|
||||||
|
let res_poll = in_progress.as_mut().poll(cx);
|
||||||
|
|
||||||
|
match res_poll {
|
||||||
|
Poll::Ready(result) => match result {
|
||||||
|
Ok(body) => {
|
||||||
|
*this.read_in_progress = ReadInProgress::Complete(body);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
break Poll::Ready(Err(StdIOError::new(
|
||||||
|
std::io::ErrorKind::Interrupted,
|
||||||
|
Box::<dyn std::error::Error + Send + Sync>::from(err.source),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Poll::Pending => {
|
||||||
|
*this.read_in_progress = ReadInProgress::InProgress(in_progress);
|
||||||
|
|
||||||
|
break Poll::Pending;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -221,6 +260,8 @@ impl AsyncSeek for HttpReader {
|
|||||||
fn start_seek(self: Pin<&mut Self>, position: std::io::SeekFrom) -> std::io::Result<()> {
|
fn start_seek(self: Pin<&mut Self>, position: std::io::SeekFrom) -> std::io::Result<()> {
|
||||||
let this = self.project();
|
let this = self.project();
|
||||||
|
|
||||||
|
this.read_in_progress.take(); // invalidate any existing reads
|
||||||
|
|
||||||
match position {
|
match position {
|
||||||
std::io::SeekFrom::Start(offset) => {
|
std::io::SeekFrom::Start(offset) => {
|
||||||
let pos_res = usize::try_from(offset);
|
let pos_res = usize::try_from(offset);
|
||||||
@@ -285,16 +326,32 @@ impl AsyncSeek for HttpReader {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn main_test() {
|
async fn main_test() {
|
||||||
use tokio::io::AsyncReadExt;
|
|
||||||
let http_url = Url::parse("https://start9.com/latest/_static/css/main.css").unwrap();
|
let http_url = Url::parse("https://start9.com/latest/_static/css/main.css").unwrap();
|
||||||
|
|
||||||
println!("Getting this resource: {}", http_url);
|
println!("Getting this resource: {}", http_url);
|
||||||
let mut test_reader = HttpReader::new(http_url).await.unwrap();
|
let mut test_reader = HttpReader::new(http_url).await.unwrap();
|
||||||
|
|
||||||
let mut buf = vec![0; test_reader.total_bytes];
|
let mut buf = Vec::new();
|
||||||
let bytes_read = test_reader.read(&mut buf).await.unwrap();
|
|
||||||
|
|
||||||
println!("bytes read: {}", bytes_read);
|
tokio::io::copy(&mut test_reader, &mut buf).await.unwrap();
|
||||||
|
|
||||||
//println!("{}", String::from_utf8(buf).unwrap());
|
assert_eq!(buf.len(), test_reader.total_bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn s9pk_test() {
|
||||||
|
use tokio::io::BufReader;
|
||||||
|
|
||||||
|
let http_url = Url::parse("https://github.com/Start9Labs/hello-world-wrapper/releases/download/v0.3.0/hello-world.s9pk").unwrap();
|
||||||
|
|
||||||
|
println!("Getting this resource: {}", http_url);
|
||||||
|
let mut test_reader =
|
||||||
|
BufReader::with_capacity(1024 * 1024, HttpReader::new(http_url).await.unwrap());
|
||||||
|
|
||||||
|
let mut s9pk = crate::s9pk::reader::S9pkReader::from_reader(test_reader, true)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let manifest = s9pk.manifest().await.unwrap();
|
||||||
|
assert_eq!(&**manifest.id, "hello-world");
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user