diff --git a/core/startos/src/util/mod.rs b/core/startos/src/util/mod.rs index 61e1bcfb4..7c2d00471 100644 --- a/core/startos/src/util/mod.rs +++ b/core/startos/src/util/mod.rs @@ -49,7 +49,7 @@ pub mod net; pub mod rpc; pub mod rpc_client; pub mod serde; -//pub mod squashfs; +pub mod squashfs; pub mod sync; pub mod tui; diff --git a/core/startos/src/util/squashfs.rs b/core/startos/src/util/squashfs.rs index 4f63e846a..d2ab135c2 100644 --- a/core/startos/src/util/squashfs.rs +++ b/core/startos/src/util/squashfs.rs @@ -98,8 +98,7 @@ impl Visit> for Superblock { #[pin_project::pin_project] pub struct MetadataBlocksWriter { - input: [u8; 8192], - size: usize, + input: PartialBuffer<[u8; 8192]>, size_addr: Option, output: PartialBuffer<[u8; 8192]>, output_flushed: usize, @@ -123,25 +122,29 @@ enum WriteState { WritingSizeHeader(u16), WritingOutput(Box), EncodingInput, - FinishingCompression, - WritingFinalSizeHeader(u64, u64), SeekingToEnd(u64), } fn poll_seek_helper( - writer: std::pin::Pin<&mut W>, + mut writer: std::pin::Pin<&mut W>, seek_state: &mut SeekState, cx: &mut std::task::Context<'_>, pos: u64, ) -> std::task::Poll> { match *seek_state { SeekState::Idle => { - writer.start_seek(std::io::SeekFrom::Start(pos))?; + writer.as_mut().start_seek(std::io::SeekFrom::Start(pos))?; *seek_state = SeekState::Seeking(pos); - Poll::Pending + match writer.as_mut().poll_complete(cx)? { + Poll::Ready(result) => { + *seek_state = SeekState::Idle; + Poll::Ready(Ok(result)) + } + Poll::Pending => Poll::Pending, + } } SeekState::Seeking(target) if target == pos => { - let result = ready!(writer.poll_complete(cx))?; + let result = ready!(writer.as_mut().poll_complete(cx))?; *seek_state = SeekState::Idle; Poll::Ready(Ok(result)) } @@ -151,35 +154,53 @@ fn poll_seek_helper( pos, old_target ); - writer.start_seek(std::io::SeekFrom::Start(pos))?; + writer.as_mut().start_seek(std::io::SeekFrom::Start(pos))?; *seek_state = SeekState::Seeking(pos); - Poll::Pending + match writer.as_mut().poll_complete(cx)? { + Poll::Ready(result) => { + *seek_state = SeekState::Idle; + Poll::Ready(Ok(result)) + } + Poll::Pending => Poll::Pending, + } } SeekState::GettingPosition => { tracing::warn!( "poll_seek({}) called while getting stream position, canceling", pos ); - writer.start_seek(std::io::SeekFrom::Start(pos))?; + writer.as_mut().start_seek(std::io::SeekFrom::Start(pos))?; *seek_state = SeekState::Seeking(pos); - Poll::Pending + match writer.as_mut().poll_complete(cx)? { + Poll::Ready(result) => { + *seek_state = SeekState::Idle; + Poll::Ready(Ok(result)) + } + Poll::Pending => Poll::Pending, + } } } } fn poll_stream_position_helper( - writer: std::pin::Pin<&mut W>, + mut writer: std::pin::Pin<&mut W>, seek_state: &mut SeekState, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { match *seek_state { SeekState::Idle => { - writer.start_seek(std::io::SeekFrom::Current(0))?; + writer.as_mut().start_seek(std::io::SeekFrom::Current(0))?; *seek_state = SeekState::GettingPosition; - Poll::Pending + match writer.as_mut().poll_complete(cx)? { + Poll::Ready(result) => { + *seek_state = SeekState::Idle; + Poll::Ready(Ok(result)) + } + Poll::Pending => Poll::Pending, + } } SeekState::GettingPosition => { - let result = ready!(writer.poll_complete(cx))?; + let result = ready!(writer.as_mut().poll_complete(cx))?; *seek_state = SeekState::Idle; Poll::Ready(Ok(result)) } @@ -188,18 +209,22 @@ fn poll_stream_position_helper( "poll_stream_position called while seeking to {}, canceling", target ); - writer.start_seek(std::io::SeekFrom::Current(0))?; + writer.as_mut().start_seek(std::io::SeekFrom::Current(0))?; *seek_state = SeekState::GettingPosition; - Poll::Pending + match writer.as_mut().poll_complete(cx)? { + Poll::Ready(result) => { + *seek_state = SeekState::Idle; + Poll::Ready(Ok(result)) + } + Poll::Pending => Poll::Pending, + } } } } impl Write for MetadataBlocksWriter { fn write(&mut self, buf: &[u8]) -> std::io::Result { - let n = buf.len().min(self.input.len() - self.size); - self.input[self.size..self.size + n].copy_from_slice(&buf[..n]); - self.size += n; + let n = self.input.copy_unwritten_from(&mut PartialBuffer::new(buf)); if n < buf.len() { self.flush()?; } @@ -207,9 +232,9 @@ impl Write for MetadataBlocksWriter { } fn flush(&mut self) -> std::io::Result<()> { loop { - match self.write_state { + match &self.write_state { WriteState::Idle => { - if self.size == 0 { + if self.input.written().is_empty() { return Ok(()); } self.write_state = WriteState::WritingSizeHeader(0); @@ -218,12 +243,12 @@ impl Write for MetadataBlocksWriter { WriteState::WritingSizeHeader(size) => { let done = if let Some(size_addr) = self.size_addr { self.writer.seek(SeekFrom::Start(size_addr))?; - Some(size_addr + size as u64) + Some(size_addr + 2 + *size as u64) } else { self.size_addr = Some(self.writer.stream_position()?); None }; - self.output.unwritten_mut()[..2].copy_from_slice(&u16::to_le_bytes(size)[..]); + self.output.unwritten_mut()[..2].copy_from_slice(&u16::to_le_bytes(*size)[..]); self.output.advance(2); self.write_state = WriteState::WritingOutput(Box::new(if let Some(end) = done { @@ -242,80 +267,33 @@ impl Write for MetadataBlocksWriter { } else { self.output.reset(); self.output_flushed = 0; - self.write_state = *next; + self.write_state = *next.clone(); } } WriteState::EncodingInput => { let encoder = self.zstd.get_or_insert_with(|| ZstdEncoder::new(22)); - let mut input = PartialBuffer::new(&self.input[..self.size]); - while !self.output.unwritten().is_empty() && !input.unwritten().is_empty() { - encoder.encode(&mut input, &mut self.output)?; - } - while !encoder.flush(&mut self.output)? {} - while !encoder.finish(&mut self.output)? {} - if !self.output.unwritten().is_empty() { - let mut input = - PartialBuffer::new(&self.input[self.input_flushed..self.size]); - encoder.encode(&mut input, &mut self.output)?; - self.input_flushed += input.written().len(); - } - self.write_state = WriteState::WritingOutput(Box::new()); - continue; - } - - WriteState::FinishingCompression => { - if !self.output.unwritten().is_empty() { - if self.zstd.as_mut().unwrap().finish(&mut self.output)? { - self.zstd = None; - } - } - if self.output.written().len() > self.output_flushed { - self.write_state = WriteState::WritingOutput; - continue; - } - if self.zstd.is_none() && self.output.written().len() == self.output_flushed { - self.output_flushed = 0; - self.output.reset(); - let end_addr = self.writer.stream_position()?; - let size_addr = self.size_addr.ok_or_else(|| { - std::io::Error::new( - std::io::ErrorKind::InvalidData, - "size_addr not set when finishing compression", - ) - })?; - self.write_state = WriteState::WritingFinalSizeHeader(size_addr, end_addr); - continue; - } - return Ok(()); - } - - WriteState::WritingFinalSizeHeader(size_addr, end_addr) => { - if self.output.written().len() > self.output_flushed { - let n = self - .writer - .write(&self.output.written()[self.output_flushed..])?; - self.output_flushed += n; - continue; - } - self.writer.seek(std::io::SeekFrom::Start(size_addr))?; - self.output.unwritten_mut()[..2] - .copy_from_slice(&((end_addr - size_addr - 2) as u16).to_le_bytes()); - self.output.advance(2); - let n = self.writer.write(&self.output.written())?; - self.output_flushed = n; - if n == 2 { - self.output_flushed = 0; - self.output.reset(); - self.write_state = WriteState::SeekingToEnd(end_addr); - } - continue; + encoder.encode( + &mut PartialBuffer::new(&self.input.written()), + &mut self.output, + )?; + let compressed = if !encoder.finish(&mut self.output)? { + std::mem::swap(&mut self.output, &mut self.input); + false + } else { + true + }; + self.zstd = None; + self.input.reset(); + self.write_state = + WriteState::WritingOutput(Box::new(WriteState::WritingSizeHeader( + self.output.written().len() as u16 + | if compressed { 0 } else { 0x8000 }, + ))); } WriteState::SeekingToEnd(end_addr) => { - self.writer.seek(std::io::SeekFrom::Start(end_addr))?; - self.input_flushed = 0; - self.size = 0; + self.writer.seek(std::io::SeekFrom::Start(*end_addr))?; self.size_addr = None; self.write_state = WriteState::Idle; return Ok(()); @@ -332,11 +310,9 @@ impl AsyncWrite for MetadataBlocksWriter { buf: &[u8], ) -> std::task::Poll> { let this = self.as_mut().project(); - let n = buf.len().min(this.input.len() - *this.size); - this.input[*this.size..*this.size + n].copy_from_slice(&buf[..n]); - *this.size += n; + let n = this.input.copy_unwritten_from(&mut PartialBuffer::new(buf)); if n < buf.len() { - ready!(self.poll_flush(cx)?); + ready!(self.poll_flush(cx))?; } Poll::Ready(Ok(n)) } @@ -347,115 +323,76 @@ impl AsyncWrite for MetadataBlocksWriter { ) -> std::task::Poll> { loop { let mut this = self.as_mut().project(); - match *this.write_state { + match this.write_state.clone() { WriteState::Idle => { - if *this.size == 0 { + if this.input.written().is_empty() { return Poll::Ready(Ok(())); } - if this.size_addr.is_none() { + *this.write_state = WriteState::WritingSizeHeader(0); + } + + WriteState::WritingSizeHeader(size) => { + let done = if let Some(size_addr) = *this.size_addr { + ready!(poll_seek_helper( + this.writer.as_mut(), + this.seek_state, + cx, + size_addr + ))?; + Some(size_addr + 2 + size as u64) + } else { let pos = ready!(poll_stream_position_helper( this.writer.as_mut(), this.seek_state, cx ))?; *this.size_addr = Some(pos); - this.output.unwritten_mut()[..2].copy_from_slice(&[0; 2]); - this.output.advance(2); - } - *this.write_state = WriteState::WritingOutput; - continue; - } - - WriteState::WritingOutput => { - if this.output.written().len() > *this.output_flushed { - let n = ready!( - this.writer - .as_mut() - .poll_write(cx, &this.output.written()[*this.output_flushed..]) - )?; - *this.output_flushed += n; - continue; - } - if this.output.written().len() == *this.output_flushed { - *this.output_flushed = 0; - this.output.reset(); - } - if *this.input_flushed < *this.size { - if !this.output.unwritten().is_empty() { - let mut input = - PartialBuffer::new(&this.input[*this.input_flushed..*this.size]); - this.zstd - .get_or_insert_with(|| ZstdEncoder::new(22)) - .encode(&mut input, this.output)?; - *this.input_flushed += input.written().len(); - } - continue; + None + }; + this.output.unwritten_mut()[..2] + .copy_from_slice(&u16::to_le_bytes(size)[..]); + this.output.advance(2); + *this.write_state = WriteState::WritingOutput(Box::new(if let Some(end) = done { + WriteState::SeekingToEnd(end) } else { - if !this.output.unwritten().is_empty() { - if this.zstd.as_mut().unwrap().finish(this.output)? { - *this.zstd = None; - } - continue; - } - if this.zstd.is_none() - && this.output.written().len() == *this.output_flushed - { - *this.output_flushed = 0; - this.output.reset(); - if let Some(size_addr) = *this.size_addr { - let end_addr = ready!(poll_stream_position_helper( - this.writer.as_mut(), - this.seek_state, - cx - ))?; - *this.write_state = - WriteState::WritingFinalSizeHeader(size_addr, end_addr); - ready!(poll_seek_helper( - this.writer.as_mut(), - this.seek_state, - cx, - size_addr - ))?; - this.output.unwritten_mut()[..2].copy_from_slice( - &((end_addr - size_addr - 2) as u16).to_le_bytes(), - ); - this.output.advance(2); - continue; - } - } - } - return Poll::Ready(Ok(())); + WriteState::EncodingInput + })); } - WriteState::WritingSizeHeader(_size_addr) => { - *this.write_state = WriteState::WritingOutput; - continue; + WriteState::WritingOutput(next) => { + if this.output.written().len() > *this.output_flushed { + let n = ready!(this + .writer + .as_mut() + .poll_write(cx, &this.output.written()[*this.output_flushed..]))?; + *this.output_flushed += n; + } else { + this.output.reset(); + *this.output_flushed = 0; + *this.write_state = *next; + } } WriteState::EncodingInput => { - *this.write_state = WriteState::WritingOutput; - continue; - } - - WriteState::FinishingCompression => { - *this.write_state = WriteState::WritingOutput; - continue; - } - - WriteState::WritingFinalSizeHeader(_size_addr, end_addr) => { - if this.output.written().len() > *this.output_flushed { - let n = ready!( - this.writer - .as_mut() - .poll_write(cx, &this.output.written()[*this.output_flushed..]) - )?; - *this.output_flushed += n; - continue; - } - *this.output_flushed = 0; - this.output.reset(); - *this.write_state = WriteState::SeekingToEnd(end_addr); - continue; + let encoder = this.zstd.get_or_insert_with(|| ZstdEncoder::new(22)); + encoder.encode( + &mut PartialBuffer::new(this.input.written()), + this.output, + )?; + let compressed = if !encoder.finish(this.output)? { + std::mem::swap(this.output, this.input); + false + } else { + true + }; + *this.zstd = None; + this.input.reset(); + *this.write_state = WriteState::WritingOutput(Box::new( + WriteState::WritingSizeHeader( + this.output.written().len() as u16 + | if compressed { 0 } else { 0x8000 }, + ), + )); } WriteState::SeekingToEnd(end_addr) => { @@ -466,8 +403,6 @@ impl AsyncWrite for MetadataBlocksWriter { end_addr ))?; *this.size_addr = None; - *this.input_flushed = 0; - *this.size = 0; *this.write_state = WriteState::Idle; return Poll::Ready(Ok(())); } @@ -486,11 +421,9 @@ impl AsyncWrite for MetadataBlocksWriter { impl MetadataBlocksWriter { pub fn new(writer: W) -> Self { Self { - input: [0; 8192], - input_flushed: 0, - size: 0, + input: PartialBuffer::new([0; 8192]), size_addr: None, - output: PartialBuffer::new([0; 4096]), + output: PartialBuffer::new([0; 8192]), output_flushed: 0, zstd: None, seek_state: SeekState::Idle, @@ -507,11 +440,10 @@ use tokio::io::AsyncRead; pub struct MetadataBlocksReader { #[pin] reader: R, - size_buf: [u8; 2], - size_bytes_read: usize, - compressed: [u8; 8192], + size_buf: PartialBuffer<[u8; 2]>, + compressed: PartialBuffer<[u8; 8192]>, compressed_size: usize, - compressed_pos: usize, + is_compressed: bool, output: PartialBuffer<[u8; 8192]>, output_pos: usize, zstd: Option, @@ -531,11 +463,10 @@ impl MetadataBlocksReader { pub fn new(reader: R) -> Self { Self { reader, - size_buf: [0; 2], - size_bytes_read: 0, - compressed: [0; 8192], + size_buf: PartialBuffer::new([0; 2]), + compressed: PartialBuffer::new([0; 8192]), compressed_size: 0, - compressed_pos: 0, + is_compressed: false, output: PartialBuffer::new([0; 8192]), output_pos: 0, zstd: None, @@ -551,11 +482,9 @@ impl Read for MetadataBlocksReader { loop { match self.state { ReadState::ReadingSize => { - let n = self - .reader - .read(&mut self.size_buf[self.size_bytes_read..])?; + let n = self.reader.read(self.size_buf.unwritten_mut())?; if n == 0 { - if self.size_bytes_read == 0 { + if self.size_buf.written().is_empty() { self.state = ReadState::Eof; return Ok(0); } else { @@ -566,56 +495,57 @@ impl Read for MetadataBlocksReader { } } - self.size_bytes_read += n; - if self.size_bytes_read < 2 { + self.size_buf.advance(n); + + if self.size_buf.written().len() < 2 { continue; } - let size_header = u16::from_le_bytes(self.size_buf); - let is_compressed = (size_header & 0x8000) == 0; - let size = (size_header & 0x7FFF) as usize; + let size_header = u16::from_le_bytes([ + self.size_buf.written()[0], + self.size_buf.written()[1], + ]); + self.is_compressed = (size_header & 0x8000) == 0; + self.compressed_size = (size_header & 0x7FFF) as usize; - if !is_compressed { + if self.compressed_size == 0 || self.compressed_size > 8192 { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, - "Uncompressed metadata blocks not supported", + format!("Invalid metadata block size: {}", self.compressed_size), )); } - if size == 0 || size > 8192 { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("Invalid metadata block size: {}", size), - )); - } - - self.compressed_size = size; - self.compressed_pos = 0; - self.size_bytes_read = 0; + self.compressed.reset(); + self.size_buf.reset(); self.state = ReadState::ReadingData; continue; } ReadState::ReadingData => { - let n = self - .reader - .read(&mut self.compressed[self.compressed_pos..self.compressed_size])?; + let n = self.reader.read(self.compressed.unwritten_mut())?; if n == 0 { return Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, - "Unexpected EOF reading compressed data", + "Unexpected EOF reading data", )); } - self.compressed_pos += n; - if self.compressed_pos < self.compressed_size { + self.compressed.advance(n); + + if !self.compressed.unwritten().is_empty() { continue; } - self.zstd = Some(ZstdDecoder::new()); self.output_pos = 0; self.output.reset(); - self.state = ReadState::Decompressing; + if self.is_compressed { + self.zstd = Some(ZstdDecoder::new()); + self.state = ReadState::Decompressing; + } else { + self.output + .copy_unwritten_from(&mut PartialBuffer::new(self.compressed.written())); + self.state = ReadState::Outputting; + } continue; } @@ -625,7 +555,7 @@ impl Read for MetadataBlocksReader { continue; } - let mut input = PartialBuffer::new(&self.compressed[..self.compressed_size]); + let mut input = PartialBuffer::new(self.compressed.written()); let decoder = self.zstd.as_mut().unwrap(); if decoder.decode(&mut input, &mut self.output)? { @@ -676,13 +606,13 @@ impl AsyncRead for MetadataBlocksReader { match *this.state { ReadState::ReadingSize => { - let mut read_buf = - tokio::io::ReadBuf::new(&mut this.size_buf[*this.size_bytes_read..]); + let mut read_buf = tokio::io::ReadBuf::new(this.size_buf.unwritten_mut()); + let before = read_buf.filled().len(); ready!(this.reader.as_mut().poll_read(cx, &mut read_buf))?; + let n = read_buf.filled().len() - before; - let n = read_buf.filled().len(); if n == 0 { - if *this.size_bytes_read == 0 { + if this.size_buf.written().is_empty() { *this.state = ReadState::Eof; return Poll::Ready(Ok(())); } else { @@ -693,22 +623,16 @@ impl AsyncRead for MetadataBlocksReader { } } - *this.size_bytes_read += n; - if *this.size_bytes_read < 2 { + this.size_buf.advance(n); + + if this.size_buf.written().len() < 2 { continue; } - let size_header = u16::from_le_bytes(*this.size_buf); - let is_compressed = (size_header & 0x8000) == 0; + let size_header = u16::from_le_bytes(*this.size_buf.written()); + *this.is_compressed = (size_header & 0x8000) == 0; let size = (size_header & 0x7FFF) as usize; - if !is_compressed { - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "Uncompressed metadata blocks not supported", - ))); - } - if size == 0 || size > 8192 { return Poll::Ready(Err(std::io::Error::new( std::io::ErrorKind::InvalidData, @@ -716,36 +640,42 @@ impl AsyncRead for MetadataBlocksReader { ))); } - *this.compressed_size = size; - *this.compressed_pos = 0; - *this.size_bytes_read = 0; + this.compressed.reset(); + this.compressed.reserve(size); + this.size_buf.reset(); *this.state = ReadState::ReadingData; continue; } ReadState::ReadingData => { - let mut read_buf = tokio::io::ReadBuf::new( - &mut this.compressed[*this.compressed_pos..*this.compressed_size], - ); + let mut read_buf = tokio::io::ReadBuf::new(this.compressed.unwritten_mut()); + let before = read_buf.filled().len(); ready!(this.reader.as_mut().poll_read(cx, &mut read_buf))?; + let n = read_buf.filled().len() - before; - let n = read_buf.filled().len(); if n == 0 { return Poll::Ready(Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, - "Unexpected EOF reading compressed data", + "Unexpected EOF reading data", ))); } - *this.compressed_pos += n; - if *this.compressed_pos < *this.compressed_size { + this.compressed.advance(n); + + if !this.compressed.unwritten().is_empty() { continue; } - *this.zstd = Some(ZstdDecoder::new()); *this.output_pos = 0; this.output.reset(); - *this.state = ReadState::Decompressing; + if *this.is_compressed { + *this.zstd = Some(ZstdDecoder::new()); + *this.state = ReadState::Decompressing; + } else { + this.output + .copy_unwritten_from(&mut PartialBuffer::new(this.compressed.written())); + *this.state = ReadState::Outputting; + } continue; } @@ -755,7 +685,7 @@ impl AsyncRead for MetadataBlocksReader { continue; } - let mut input = PartialBuffer::new(&this.compressed[..*this.compressed_size]); + let mut input = PartialBuffer::new(this.compressed.written()); let decoder = this.zstd.as_mut().unwrap(); if decoder.decode(&mut input, this.output)? {