squashfs-wip

This commit is contained in:
Aiden McClelland
2025-11-07 03:13:54 -07:00
parent 68f401bfa3
commit e7847d0e88
2 changed files with 192 additions and 262 deletions

View File

@@ -49,7 +49,7 @@ pub mod net;
pub mod rpc; pub mod rpc;
pub mod rpc_client; pub mod rpc_client;
pub mod serde; pub mod serde;
//pub mod squashfs; pub mod squashfs;
pub mod sync; pub mod sync;
pub mod tui; pub mod tui;

View File

@@ -98,8 +98,7 @@ impl<W: Write> Visit<SquashfsSerializer<W>> for Superblock {
#[pin_project::pin_project] #[pin_project::pin_project]
pub struct MetadataBlocksWriter<W> { pub struct MetadataBlocksWriter<W> {
input: [u8; 8192], input: PartialBuffer<[u8; 8192]>,
size: usize,
size_addr: Option<u64>, size_addr: Option<u64>,
output: PartialBuffer<[u8; 8192]>, output: PartialBuffer<[u8; 8192]>,
output_flushed: usize, output_flushed: usize,
@@ -123,25 +122,29 @@ enum WriteState {
WritingSizeHeader(u16), WritingSizeHeader(u16),
WritingOutput(Box<Self>), WritingOutput(Box<Self>),
EncodingInput, EncodingInput,
FinishingCompression,
WritingFinalSizeHeader(u64, u64),
SeekingToEnd(u64), SeekingToEnd(u64),
} }
fn poll_seek_helper<W: AsyncSeek>( fn poll_seek_helper<W: AsyncSeek>(
writer: std::pin::Pin<&mut W>, mut writer: std::pin::Pin<&mut W>,
seek_state: &mut SeekState, seek_state: &mut SeekState,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
pos: u64, pos: u64,
) -> std::task::Poll<std::io::Result<u64>> { ) -> std::task::Poll<std::io::Result<u64>> {
match *seek_state { match *seek_state {
SeekState::Idle => { SeekState::Idle => {
writer.start_seek(std::io::SeekFrom::Start(pos))?; writer.as_mut().start_seek(std::io::SeekFrom::Start(pos))?;
*seek_state = SeekState::Seeking(pos); *seek_state = SeekState::Seeking(pos);
Poll::Pending match writer.as_mut().poll_complete(cx)? {
Poll::Ready(result) => {
*seek_state = SeekState::Idle;
Poll::Ready(Ok(result))
}
Poll::Pending => Poll::Pending,
}
} }
SeekState::Seeking(target) if target == pos => { SeekState::Seeking(target) if target == pos => {
let result = ready!(writer.poll_complete(cx))?; let result = ready!(writer.as_mut().poll_complete(cx))?;
*seek_state = SeekState::Idle; *seek_state = SeekState::Idle;
Poll::Ready(Ok(result)) Poll::Ready(Ok(result))
} }
@@ -151,35 +154,53 @@ fn poll_seek_helper<W: AsyncSeek>(
pos, pos,
old_target old_target
); );
writer.start_seek(std::io::SeekFrom::Start(pos))?; writer.as_mut().start_seek(std::io::SeekFrom::Start(pos))?;
*seek_state = SeekState::Seeking(pos); *seek_state = SeekState::Seeking(pos);
Poll::Pending match writer.as_mut().poll_complete(cx)? {
Poll::Ready(result) => {
*seek_state = SeekState::Idle;
Poll::Ready(Ok(result))
}
Poll::Pending => Poll::Pending,
}
} }
SeekState::GettingPosition => { SeekState::GettingPosition => {
tracing::warn!( tracing::warn!(
"poll_seek({}) called while getting stream position, canceling", "poll_seek({}) called while getting stream position, canceling",
pos pos
); );
writer.start_seek(std::io::SeekFrom::Start(pos))?; writer.as_mut().start_seek(std::io::SeekFrom::Start(pos))?;
*seek_state = SeekState::Seeking(pos); *seek_state = SeekState::Seeking(pos);
Poll::Pending match writer.as_mut().poll_complete(cx)? {
Poll::Ready(result) => {
*seek_state = SeekState::Idle;
Poll::Ready(Ok(result))
}
Poll::Pending => Poll::Pending,
}
} }
} }
} }
fn poll_stream_position_helper<W: AsyncSeek>( fn poll_stream_position_helper<W: AsyncSeek>(
writer: std::pin::Pin<&mut W>, mut writer: std::pin::Pin<&mut W>,
seek_state: &mut SeekState, seek_state: &mut SeekState,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<u64>> { ) -> std::task::Poll<std::io::Result<u64>> {
match *seek_state { match *seek_state {
SeekState::Idle => { SeekState::Idle => {
writer.start_seek(std::io::SeekFrom::Current(0))?; writer.as_mut().start_seek(std::io::SeekFrom::Current(0))?;
*seek_state = SeekState::GettingPosition; *seek_state = SeekState::GettingPosition;
Poll::Pending match writer.as_mut().poll_complete(cx)? {
Poll::Ready(result) => {
*seek_state = SeekState::Idle;
Poll::Ready(Ok(result))
}
Poll::Pending => Poll::Pending,
}
} }
SeekState::GettingPosition => { SeekState::GettingPosition => {
let result = ready!(writer.poll_complete(cx))?; let result = ready!(writer.as_mut().poll_complete(cx))?;
*seek_state = SeekState::Idle; *seek_state = SeekState::Idle;
Poll::Ready(Ok(result)) Poll::Ready(Ok(result))
} }
@@ -188,18 +209,22 @@ fn poll_stream_position_helper<W: AsyncSeek>(
"poll_stream_position called while seeking to {}, canceling", "poll_stream_position called while seeking to {}, canceling",
target target
); );
writer.start_seek(std::io::SeekFrom::Current(0))?; writer.as_mut().start_seek(std::io::SeekFrom::Current(0))?;
*seek_state = SeekState::GettingPosition; *seek_state = SeekState::GettingPosition;
Poll::Pending match writer.as_mut().poll_complete(cx)? {
Poll::Ready(result) => {
*seek_state = SeekState::Idle;
Poll::Ready(Ok(result))
}
Poll::Pending => Poll::Pending,
}
} }
} }
} }
impl<W: Write + Seek> Write for MetadataBlocksWriter<W> { impl<W: Write + Seek> Write for MetadataBlocksWriter<W> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let n = buf.len().min(self.input.len() - self.size); let n = self.input.copy_unwritten_from(&mut PartialBuffer::new(buf));
self.input[self.size..self.size + n].copy_from_slice(&buf[..n]);
self.size += n;
if n < buf.len() { if n < buf.len() {
self.flush()?; self.flush()?;
} }
@@ -207,9 +232,9 @@ impl<W: Write + Seek> Write for MetadataBlocksWriter<W> {
} }
fn flush(&mut self) -> std::io::Result<()> { fn flush(&mut self) -> std::io::Result<()> {
loop { loop {
match self.write_state { match &self.write_state {
WriteState::Idle => { WriteState::Idle => {
if self.size == 0 { if self.input.written().is_empty() {
return Ok(()); return Ok(());
} }
self.write_state = WriteState::WritingSizeHeader(0); self.write_state = WriteState::WritingSizeHeader(0);
@@ -218,12 +243,12 @@ impl<W: Write + Seek> Write for MetadataBlocksWriter<W> {
WriteState::WritingSizeHeader(size) => { WriteState::WritingSizeHeader(size) => {
let done = if let Some(size_addr) = self.size_addr { let done = if let Some(size_addr) = self.size_addr {
self.writer.seek(SeekFrom::Start(size_addr))?; self.writer.seek(SeekFrom::Start(size_addr))?;
Some(size_addr + size as u64) Some(size_addr + 2 + *size as u64)
} else { } else {
self.size_addr = Some(self.writer.stream_position()?); self.size_addr = Some(self.writer.stream_position()?);
None None
}; };
self.output.unwritten_mut()[..2].copy_from_slice(&u16::to_le_bytes(size)[..]); self.output.unwritten_mut()[..2].copy_from_slice(&u16::to_le_bytes(*size)[..]);
self.output.advance(2); self.output.advance(2);
self.write_state = self.write_state =
WriteState::WritingOutput(Box::new(if let Some(end) = done { WriteState::WritingOutput(Box::new(if let Some(end) = done {
@@ -242,80 +267,33 @@ impl<W: Write + Seek> Write for MetadataBlocksWriter<W> {
} else { } else {
self.output.reset(); self.output.reset();
self.output_flushed = 0; self.output_flushed = 0;
self.write_state = *next; self.write_state = *next.clone();
} }
} }
WriteState::EncodingInput => { WriteState::EncodingInput => {
let encoder = self.zstd.get_or_insert_with(|| ZstdEncoder::new(22)); let encoder = self.zstd.get_or_insert_with(|| ZstdEncoder::new(22));
let mut input = PartialBuffer::new(&self.input[..self.size]); encoder.encode(
while !self.output.unwritten().is_empty() && !input.unwritten().is_empty() { &mut PartialBuffer::new(&self.input.written()),
encoder.encode(&mut input, &mut self.output)?; &mut self.output,
} )?;
while !encoder.flush(&mut self.output)? {} let compressed = if !encoder.finish(&mut self.output)? {
while !encoder.finish(&mut self.output)? {} std::mem::swap(&mut self.output, &mut self.input);
if !self.output.unwritten().is_empty() { false
let mut input = } else {
PartialBuffer::new(&self.input[self.input_flushed..self.size]); true
encoder.encode(&mut input, &mut self.output)?; };
self.input_flushed += input.written().len(); self.zstd = None;
} self.input.reset();
self.write_state = WriteState::WritingOutput(Box::new()); self.write_state =
continue; WriteState::WritingOutput(Box::new(WriteState::WritingSizeHeader(
} self.output.written().len() as u16
| if compressed { 0 } else { 0x8000 },
WriteState::FinishingCompression => { )));
if !self.output.unwritten().is_empty() {
if self.zstd.as_mut().unwrap().finish(&mut self.output)? {
self.zstd = None;
}
}
if self.output.written().len() > self.output_flushed {
self.write_state = WriteState::WritingOutput;
continue;
}
if self.zstd.is_none() && self.output.written().len() == self.output_flushed {
self.output_flushed = 0;
self.output.reset();
let end_addr = self.writer.stream_position()?;
let size_addr = self.size_addr.ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
"size_addr not set when finishing compression",
)
})?;
self.write_state = WriteState::WritingFinalSizeHeader(size_addr, end_addr);
continue;
}
return Ok(());
}
WriteState::WritingFinalSizeHeader(size_addr, end_addr) => {
if self.output.written().len() > self.output_flushed {
let n = self
.writer
.write(&self.output.written()[self.output_flushed..])?;
self.output_flushed += n;
continue;
}
self.writer.seek(std::io::SeekFrom::Start(size_addr))?;
self.output.unwritten_mut()[..2]
.copy_from_slice(&((end_addr - size_addr - 2) as u16).to_le_bytes());
self.output.advance(2);
let n = self.writer.write(&self.output.written())?;
self.output_flushed = n;
if n == 2 {
self.output_flushed = 0;
self.output.reset();
self.write_state = WriteState::SeekingToEnd(end_addr);
}
continue;
} }
WriteState::SeekingToEnd(end_addr) => { WriteState::SeekingToEnd(end_addr) => {
self.writer.seek(std::io::SeekFrom::Start(end_addr))?; self.writer.seek(std::io::SeekFrom::Start(*end_addr))?;
self.input_flushed = 0;
self.size = 0;
self.size_addr = None; self.size_addr = None;
self.write_state = WriteState::Idle; self.write_state = WriteState::Idle;
return Ok(()); return Ok(());
@@ -332,11 +310,9 @@ impl<W: AsyncWrite + AsyncSeek> AsyncWrite for MetadataBlocksWriter<W> {
buf: &[u8], buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> { ) -> std::task::Poll<std::io::Result<usize>> {
let this = self.as_mut().project(); let this = self.as_mut().project();
let n = buf.len().min(this.input.len() - *this.size); let n = this.input.copy_unwritten_from(&mut PartialBuffer::new(buf));
this.input[*this.size..*this.size + n].copy_from_slice(&buf[..n]);
*this.size += n;
if n < buf.len() { if n < buf.len() {
ready!(self.poll_flush(cx)?); ready!(self.poll_flush(cx))?;
} }
Poll::Ready(Ok(n)) Poll::Ready(Ok(n))
} }
@@ -347,115 +323,76 @@ impl<W: AsyncWrite + AsyncSeek> AsyncWrite for MetadataBlocksWriter<W> {
) -> std::task::Poll<std::io::Result<()>> { ) -> std::task::Poll<std::io::Result<()>> {
loop { loop {
let mut this = self.as_mut().project(); let mut this = self.as_mut().project();
match *this.write_state { match this.write_state.clone() {
WriteState::Idle => { WriteState::Idle => {
if *this.size == 0 { if this.input.written().is_empty() {
return Poll::Ready(Ok(())); return Poll::Ready(Ok(()));
} }
if this.size_addr.is_none() { *this.write_state = WriteState::WritingSizeHeader(0);
}
WriteState::WritingSizeHeader(size) => {
let done = if let Some(size_addr) = *this.size_addr {
ready!(poll_seek_helper(
this.writer.as_mut(),
this.seek_state,
cx,
size_addr
))?;
Some(size_addr + 2 + size as u64)
} else {
let pos = ready!(poll_stream_position_helper( let pos = ready!(poll_stream_position_helper(
this.writer.as_mut(), this.writer.as_mut(),
this.seek_state, this.seek_state,
cx cx
))?; ))?;
*this.size_addr = Some(pos); *this.size_addr = Some(pos);
this.output.unwritten_mut()[..2].copy_from_slice(&[0; 2]); None
this.output.advance(2); };
} this.output.unwritten_mut()[..2]
*this.write_state = WriteState::WritingOutput; .copy_from_slice(&u16::to_le_bytes(size)[..]);
continue; this.output.advance(2);
} *this.write_state = WriteState::WritingOutput(Box::new(if let Some(end) = done {
WriteState::SeekingToEnd(end)
WriteState::WritingOutput => {
if this.output.written().len() > *this.output_flushed {
let n = ready!(
this.writer
.as_mut()
.poll_write(cx, &this.output.written()[*this.output_flushed..])
)?;
*this.output_flushed += n;
continue;
}
if this.output.written().len() == *this.output_flushed {
*this.output_flushed = 0;
this.output.reset();
}
if *this.input_flushed < *this.size {
if !this.output.unwritten().is_empty() {
let mut input =
PartialBuffer::new(&this.input[*this.input_flushed..*this.size]);
this.zstd
.get_or_insert_with(|| ZstdEncoder::new(22))
.encode(&mut input, this.output)?;
*this.input_flushed += input.written().len();
}
continue;
} else { } else {
if !this.output.unwritten().is_empty() { WriteState::EncodingInput
if this.zstd.as_mut().unwrap().finish(this.output)? { }));
*this.zstd = None;
}
continue;
}
if this.zstd.is_none()
&& this.output.written().len() == *this.output_flushed
{
*this.output_flushed = 0;
this.output.reset();
if let Some(size_addr) = *this.size_addr {
let end_addr = ready!(poll_stream_position_helper(
this.writer.as_mut(),
this.seek_state,
cx
))?;
*this.write_state =
WriteState::WritingFinalSizeHeader(size_addr, end_addr);
ready!(poll_seek_helper(
this.writer.as_mut(),
this.seek_state,
cx,
size_addr
))?;
this.output.unwritten_mut()[..2].copy_from_slice(
&((end_addr - size_addr - 2) as u16).to_le_bytes(),
);
this.output.advance(2);
continue;
}
}
}
return Poll::Ready(Ok(()));
} }
WriteState::WritingSizeHeader(_size_addr) => { WriteState::WritingOutput(next) => {
*this.write_state = WriteState::WritingOutput; if this.output.written().len() > *this.output_flushed {
continue; let n = ready!(this
.writer
.as_mut()
.poll_write(cx, &this.output.written()[*this.output_flushed..]))?;
*this.output_flushed += n;
} else {
this.output.reset();
*this.output_flushed = 0;
*this.write_state = *next;
}
} }
WriteState::EncodingInput => { WriteState::EncodingInput => {
*this.write_state = WriteState::WritingOutput; let encoder = this.zstd.get_or_insert_with(|| ZstdEncoder::new(22));
continue; encoder.encode(
} &mut PartialBuffer::new(this.input.written()),
this.output,
WriteState::FinishingCompression => { )?;
*this.write_state = WriteState::WritingOutput; let compressed = if !encoder.finish(this.output)? {
continue; std::mem::swap(this.output, this.input);
} false
} else {
WriteState::WritingFinalSizeHeader(_size_addr, end_addr) => { true
if this.output.written().len() > *this.output_flushed { };
let n = ready!( *this.zstd = None;
this.writer this.input.reset();
.as_mut() *this.write_state = WriteState::WritingOutput(Box::new(
.poll_write(cx, &this.output.written()[*this.output_flushed..]) WriteState::WritingSizeHeader(
)?; this.output.written().len() as u16
*this.output_flushed += n; | if compressed { 0 } else { 0x8000 },
continue; ),
} ));
*this.output_flushed = 0;
this.output.reset();
*this.write_state = WriteState::SeekingToEnd(end_addr);
continue;
} }
WriteState::SeekingToEnd(end_addr) => { WriteState::SeekingToEnd(end_addr) => {
@@ -466,8 +403,6 @@ impl<W: AsyncWrite + AsyncSeek> AsyncWrite for MetadataBlocksWriter<W> {
end_addr end_addr
))?; ))?;
*this.size_addr = None; *this.size_addr = None;
*this.input_flushed = 0;
*this.size = 0;
*this.write_state = WriteState::Idle; *this.write_state = WriteState::Idle;
return Poll::Ready(Ok(())); return Poll::Ready(Ok(()));
} }
@@ -486,11 +421,9 @@ impl<W: AsyncWrite + AsyncSeek> AsyncWrite for MetadataBlocksWriter<W> {
impl<W> MetadataBlocksWriter<W> { impl<W> MetadataBlocksWriter<W> {
pub fn new(writer: W) -> Self { pub fn new(writer: W) -> Self {
Self { Self {
input: [0; 8192], input: PartialBuffer::new([0; 8192]),
input_flushed: 0,
size: 0,
size_addr: None, size_addr: None,
output: PartialBuffer::new([0; 4096]), output: PartialBuffer::new([0; 8192]),
output_flushed: 0, output_flushed: 0,
zstd: None, zstd: None,
seek_state: SeekState::Idle, seek_state: SeekState::Idle,
@@ -507,11 +440,10 @@ use tokio::io::AsyncRead;
pub struct MetadataBlocksReader<R> { pub struct MetadataBlocksReader<R> {
#[pin] #[pin]
reader: R, reader: R,
size_buf: [u8; 2], size_buf: PartialBuffer<[u8; 2]>,
size_bytes_read: usize, compressed: PartialBuffer<[u8; 8192]>,
compressed: [u8; 8192],
compressed_size: usize, compressed_size: usize,
compressed_pos: usize, is_compressed: bool,
output: PartialBuffer<[u8; 8192]>, output: PartialBuffer<[u8; 8192]>,
output_pos: usize, output_pos: usize,
zstd: Option<ZstdDecoder>, zstd: Option<ZstdDecoder>,
@@ -531,11 +463,10 @@ impl<R> MetadataBlocksReader<R> {
pub fn new(reader: R) -> Self { pub fn new(reader: R) -> Self {
Self { Self {
reader, reader,
size_buf: [0; 2], size_buf: PartialBuffer::new([0; 2]),
size_bytes_read: 0, compressed: PartialBuffer::new([0; 8192]),
compressed: [0; 8192],
compressed_size: 0, compressed_size: 0,
compressed_pos: 0, is_compressed: false,
output: PartialBuffer::new([0; 8192]), output: PartialBuffer::new([0; 8192]),
output_pos: 0, output_pos: 0,
zstd: None, zstd: None,
@@ -551,11 +482,9 @@ impl<R: Read> Read for MetadataBlocksReader<R> {
loop { loop {
match self.state { match self.state {
ReadState::ReadingSize => { ReadState::ReadingSize => {
let n = self let n = self.reader.read(self.size_buf.unwritten_mut())?;
.reader
.read(&mut self.size_buf[self.size_bytes_read..])?;
if n == 0 { if n == 0 {
if self.size_bytes_read == 0 { if self.size_buf.written().is_empty() {
self.state = ReadState::Eof; self.state = ReadState::Eof;
return Ok(0); return Ok(0);
} else { } else {
@@ -566,56 +495,57 @@ impl<R: Read> Read for MetadataBlocksReader<R> {
} }
} }
self.size_bytes_read += n; self.size_buf.advance(n);
if self.size_bytes_read < 2 {
if self.size_buf.written().len() < 2 {
continue; continue;
} }
let size_header = u16::from_le_bytes(self.size_buf); let size_header = u16::from_le_bytes([
let is_compressed = (size_header & 0x8000) == 0; self.size_buf.written()[0],
let size = (size_header & 0x7FFF) as usize; self.size_buf.written()[1],
]);
self.is_compressed = (size_header & 0x8000) == 0;
self.compressed_size = (size_header & 0x7FFF) as usize;
if !is_compressed { if self.compressed_size == 0 || self.compressed_size > 8192 {
return Err(std::io::Error::new( return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData, std::io::ErrorKind::InvalidData,
"Uncompressed metadata blocks not supported", format!("Invalid metadata block size: {}", self.compressed_size),
)); ));
} }
if size == 0 || size > 8192 { self.compressed.reset();
return Err(std::io::Error::new( self.size_buf.reset();
std::io::ErrorKind::InvalidData,
format!("Invalid metadata block size: {}", size),
));
}
self.compressed_size = size;
self.compressed_pos = 0;
self.size_bytes_read = 0;
self.state = ReadState::ReadingData; self.state = ReadState::ReadingData;
continue; continue;
} }
ReadState::ReadingData => { ReadState::ReadingData => {
let n = self let n = self.reader.read(self.compressed.unwritten_mut())?;
.reader
.read(&mut self.compressed[self.compressed_pos..self.compressed_size])?;
if n == 0 { if n == 0 {
return Err(std::io::Error::new( return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof, std::io::ErrorKind::UnexpectedEof,
"Unexpected EOF reading compressed data", "Unexpected EOF reading data",
)); ));
} }
self.compressed_pos += n; self.compressed.advance(n);
if self.compressed_pos < self.compressed_size {
if !self.compressed.unwritten().is_empty() {
continue; continue;
} }
self.zstd = Some(ZstdDecoder::new());
self.output_pos = 0; self.output_pos = 0;
self.output.reset(); self.output.reset();
self.state = ReadState::Decompressing; if self.is_compressed {
self.zstd = Some(ZstdDecoder::new());
self.state = ReadState::Decompressing;
} else {
self.output
.copy_unwritten_from(&mut PartialBuffer::new(self.compressed.written()));
self.state = ReadState::Outputting;
}
continue; continue;
} }
@@ -625,7 +555,7 @@ impl<R: Read> Read for MetadataBlocksReader<R> {
continue; continue;
} }
let mut input = PartialBuffer::new(&self.compressed[..self.compressed_size]); let mut input = PartialBuffer::new(self.compressed.written());
let decoder = self.zstd.as_mut().unwrap(); let decoder = self.zstd.as_mut().unwrap();
if decoder.decode(&mut input, &mut self.output)? { if decoder.decode(&mut input, &mut self.output)? {
@@ -676,13 +606,13 @@ impl<R: AsyncRead> AsyncRead for MetadataBlocksReader<R> {
match *this.state { match *this.state {
ReadState::ReadingSize => { ReadState::ReadingSize => {
let mut read_buf = let mut read_buf = tokio::io::ReadBuf::new(this.size_buf.unwritten_mut());
tokio::io::ReadBuf::new(&mut this.size_buf[*this.size_bytes_read..]); let before = read_buf.filled().len();
ready!(this.reader.as_mut().poll_read(cx, &mut read_buf))?; ready!(this.reader.as_mut().poll_read(cx, &mut read_buf))?;
let n = read_buf.filled().len() - before;
let n = read_buf.filled().len();
if n == 0 { if n == 0 {
if *this.size_bytes_read == 0 { if this.size_buf.written().is_empty() {
*this.state = ReadState::Eof; *this.state = ReadState::Eof;
return Poll::Ready(Ok(())); return Poll::Ready(Ok(()));
} else { } else {
@@ -693,22 +623,16 @@ impl<R: AsyncRead> AsyncRead for MetadataBlocksReader<R> {
} }
} }
*this.size_bytes_read += n; this.size_buf.advance(n);
if *this.size_bytes_read < 2 {
if this.size_buf.written().len() < 2 {
continue; continue;
} }
let size_header = u16::from_le_bytes(*this.size_buf); let size_header = u16::from_le_bytes(*this.size_buf.written());
let is_compressed = (size_header & 0x8000) == 0; *this.is_compressed = (size_header & 0x8000) == 0;
let size = (size_header & 0x7FFF) as usize; let size = (size_header & 0x7FFF) as usize;
if !is_compressed {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Uncompressed metadata blocks not supported",
)));
}
if size == 0 || size > 8192 { if size == 0 || size > 8192 {
return Poll::Ready(Err(std::io::Error::new( return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::InvalidData, std::io::ErrorKind::InvalidData,
@@ -716,36 +640,42 @@ impl<R: AsyncRead> AsyncRead for MetadataBlocksReader<R> {
))); )));
} }
*this.compressed_size = size; this.compressed.reset();
*this.compressed_pos = 0; this.compressed.reserve(size);
*this.size_bytes_read = 0; this.size_buf.reset();
*this.state = ReadState::ReadingData; *this.state = ReadState::ReadingData;
continue; continue;
} }
ReadState::ReadingData => { ReadState::ReadingData => {
let mut read_buf = tokio::io::ReadBuf::new( let mut read_buf = tokio::io::ReadBuf::new(this.compressed.unwritten_mut());
&mut this.compressed[*this.compressed_pos..*this.compressed_size], let before = read_buf.filled().len();
);
ready!(this.reader.as_mut().poll_read(cx, &mut read_buf))?; ready!(this.reader.as_mut().poll_read(cx, &mut read_buf))?;
let n = read_buf.filled().len() - before;
let n = read_buf.filled().len();
if n == 0 { if n == 0 {
return Poll::Ready(Err(std::io::Error::new( return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof, std::io::ErrorKind::UnexpectedEof,
"Unexpected EOF reading compressed data", "Unexpected EOF reading data",
))); )));
} }
*this.compressed_pos += n; this.compressed.advance(n);
if *this.compressed_pos < *this.compressed_size {
if !this.compressed.unwritten().is_empty() {
continue; continue;
} }
*this.zstd = Some(ZstdDecoder::new());
*this.output_pos = 0; *this.output_pos = 0;
this.output.reset(); this.output.reset();
*this.state = ReadState::Decompressing; if *this.is_compressed {
*this.zstd = Some(ZstdDecoder::new());
*this.state = ReadState::Decompressing;
} else {
this.output
.copy_unwritten_from(&mut PartialBuffer::new(this.compressed.written()));
*this.state = ReadState::Outputting;
}
continue; continue;
} }
@@ -755,7 +685,7 @@ impl<R: AsyncRead> AsyncRead for MetadataBlocksReader<R> {
continue; continue;
} }
let mut input = PartialBuffer::new(&this.compressed[..*this.compressed_size]); let mut input = PartialBuffer::new(this.compressed.written());
let decoder = this.zstd.as_mut().unwrap(); let decoder = this.zstd.as_mut().unwrap();
if decoder.decode(&mut input, this.output)? { if decoder.decode(&mut input, this.output)? {