diff --git a/core/startos/Cargo.toml b/core/startos/Cargo.toml index ae4382bc2..2e6df1e3d 100644 --- a/core/startos/Cargo.toml +++ b/core/startos/Cargo.toml @@ -98,7 +98,12 @@ hex = "0.4.3" hmac = "0.12.1" http = "1.0.0" http-body-util = "0.1" -hyper-util = { version = "0.1.5", features = ["tokio", "service"] } +hyper-util = { version = "0.1.5", features = [ + "tokio", + "service", + "http1", + "http2", +] } id-pool = { version = "0.2.2", default-features = false, features = [ "serde", "u16", diff --git a/core/startos/src/net/vhost.rs b/core/startos/src/net/vhost.rs index cdd752709..9fc7c8384 100644 --- a/core/startos/src/net/vhost.rs +++ b/core/startos/src/net/vhost.rs @@ -13,6 +13,7 @@ use http::Uri; use imbl_value::InternedString; use models::ResultExt; use serde::{Deserialize, Serialize}; +use tokio::io::AsyncWriteExt; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{Mutex, RwLock}; use tokio_rustls::rustls::pki_types::{ @@ -27,7 +28,7 @@ use ts_rs::TS; use crate::db::model::Database; use crate::net::static_server::server_error; use crate::prelude::*; -use crate::util::io::BackTrackingReader; +use crate::util::io::BackTrackingIO; use crate::util::serde::MaybeUtf8String; // not allowed: <=1024, >=32768, 5355, 5432, 9050, 6010, 9051, 5353 @@ -129,8 +130,7 @@ impl VHostServer { tracing::debug!("{e:?}"); } - let mut stream = BackTrackingReader::new(stream); - stream.start_buffering(); + let mut stream = BackTrackingIO::new(stream); let mapping = mapping.clone(); let db = db.clone(); tokio::spawn(async move { @@ -156,6 +156,7 @@ impl VHostServer { .and_then(|host| host.to_str().ok()); let uri = Uri::from_parts({ let mut parts = req.uri().to_owned().into_parts(); + parts.scheme = Some("https".parse()?); parts.authority = host.map(FromStr::from_str).transpose()?; parts })?; @@ -313,8 +314,12 @@ impl VHostServer { ) .await .with_kind(crate::ErrorKind::OpenSsl)?; + let mut accept = mid.into_stream(Arc::new(cfg)); + let io = accept.get_mut().unwrap(); + let buffered = io.stop_buffering(); + io.write_all(&buffered).await?; let mut tls_stream = - match mid.into_stream(Arc::new(cfg)).await { + match accept.await { Ok(a) => a, Err(e) => { tracing::trace!( "VHostController: failed to accept TLS connection on port {port}: {e}"); @@ -322,7 +327,6 @@ impl VHostServer { return Ok(()) } }; - tls_stream.get_mut().0.stop_buffering(); tokio::io::copy_bidirectional( &mut tls_stream, &mut target_stream, @@ -335,8 +339,12 @@ impl VHostServer { { cfg.alpn_protocols.push(proto.into()); } + let mut accept = mid.into_stream(Arc::new(cfg)); + let io = accept.get_mut().unwrap(); + let buffered = io.stop_buffering(); + io.write_all(&buffered).await?; let mut tls_stream = - match mid.into_stream(Arc::new(cfg)).await { + match accept.await { Ok(a) => a, Err(e) => { tracing::trace!( "VHostController: failed to accept TLS connection on port {port}: {e}"); @@ -344,7 +352,6 @@ impl VHostServer { return Ok(()) } }; - tls_stream.get_mut().0.stop_buffering(); tokio::io::copy_bidirectional( &mut tls_stream, &mut tcp_stream, @@ -353,8 +360,12 @@ impl VHostServer { } Err(AlpnInfo::Specified(alpn)) => { cfg.alpn_protocols = alpn.into_iter().map(|a| a.0).collect(); + let mut accept = mid.into_stream(Arc::new(cfg)); + let io = accept.get_mut().unwrap(); + let buffered = io.stop_buffering(); + io.write_all(&buffered).await?; let mut tls_stream = - match mid.into_stream(Arc::new(cfg)).await { + match accept.await { Ok(a) => a, Err(e) => { tracing::trace!( "VHostController: failed to accept TLS connection on port {port}: {e}"); @@ -362,7 +373,6 @@ impl VHostServer { return Ok(()) } }; - tls_stream.get_mut().0.stop_buffering(); tokio::io::copy_bidirectional( &mut tls_stream, &mut tcp_stream, diff --git a/core/startos/src/util/io.rs b/core/startos/src/util/io.rs index bba45fa69..6d9c8a4ff 100644 --- a/core/startos/src/util/io.rs +++ b/core/startos/src/util/io.rs @@ -411,107 +411,151 @@ impl> CursorExt for Cursor { } } -#[pin_project::pin_project] #[derive(Debug)] -pub struct BackTrackingReader { - #[pin] - reader: T, - buffer: Cursor>, - buffering: bool, +enum BTBuffer { + NotBuffering, + Buffering { read: Vec, write: Vec }, + Rewound { read: Cursor> }, } -impl BackTrackingReader { - pub fn new(reader: T) -> Self { - Self { - reader, - buffer: Cursor::new(Vec::new()), - buffering: false, - } - } - pub fn start_buffering(&mut self) { - self.buffer.set_position(0); - self.buffer.get_mut().truncate(0); - self.buffering = true; - } - pub fn stop_buffering(&mut self) { - self.buffer.set_position(0); - self.buffer.get_mut().truncate(0); - self.buffering = false; - } - pub fn rewind(&mut self) { - self.buffering = false; - } - pub fn unwrap(self) -> T { - self.reader +impl Default for BTBuffer { + fn default() -> Self { + BTBuffer::NotBuffering } } -impl AsyncRead for BackTrackingReader { +#[pin_project::pin_project] +#[derive(Debug)] +pub struct BackTrackingIO { + #[pin] + io: T, + buffer: BTBuffer, +} +impl BackTrackingIO { + pub fn new(io: T) -> Self { + Self { + io, + buffer: BTBuffer::Buffering { + read: Vec::new(), + write: Vec::new(), + }, + } + } + #[must_use] + pub fn stop_buffering(&mut self) -> Vec { + match std::mem::take(&mut self.buffer) { + BTBuffer::Buffering { write, .. } => write, + BTBuffer::NotBuffering => Vec::new(), + BTBuffer::Rewound { read } => { + self.buffer = BTBuffer::Rewound { read }; + Vec::new() + } + } + } + pub fn rewind(&mut self) -> Vec { + match std::mem::take(&mut self.buffer) { + BTBuffer::Buffering { read, write } => { + self.buffer = BTBuffer::Rewound { + read: Cursor::new(read), + }; + write + } + BTBuffer::NotBuffering => Vec::new(), + BTBuffer::Rewound { read } => { + self.buffer = BTBuffer::Rewound { read }; + Vec::new() + } + } + } + pub fn unwrap(self) -> T { + self.io + } +} + +impl AsyncRead for BackTrackingIO { fn poll_read( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { let this = self.project(); - if *this.buffering { - let filled = buf.filled().len(); - let res = this.reader.poll_read(cx, buf); - this.buffer - .get_mut() - .extend_from_slice(&buf.filled()[filled..]); - res - } else { - let mut ready = false; - if (this.buffer.position() as usize) < this.buffer.get_ref().len() { - this.buffer.pure_read(buf); - ready = true; + match this.buffer { + BTBuffer::Buffering { read, .. } => { + let filled = buf.filled().len(); + let res = this.io.poll_read(cx, buf); + read.extend_from_slice(&buf.filled()[filled..]); + res } - if buf.remaining() > 0 { - match this.reader.poll_read(cx, buf) { - Poll::Pending => { - if ready { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } - } - a => a, + BTBuffer::NotBuffering => this.io.poll_read(cx, buf), + BTBuffer::Rewound { read } => { + let mut ready = false; + if (read.position() as usize) < read.get_ref().len() { + read.pure_read(buf); + ready = true; + } + if buf.remaining() > 0 { + match this.io.poll_read(cx, buf) { + Poll::Pending => { + if ready { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + a => a, + } + } else { + Poll::Ready(Ok(())) } - } else { - Poll::Ready(Ok(())) } } } } -impl AsyncWrite for BackTrackingReader { +impl AsyncWrite for BackTrackingIO { fn is_write_vectored(&self) -> bool { - self.reader.is_write_vectored() + self.io.is_write_vectored() } fn poll_flush( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { - self.project().reader.poll_flush(cx) + self.project().io.poll_flush(cx) } fn poll_shutdown( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { - self.project().reader.poll_shutdown(cx) + self.project().io.poll_shutdown(cx) } fn poll_write( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8], ) -> Poll> { - self.project().reader.poll_write(cx, buf) + let this = self.project(); + if let BTBuffer::Buffering { write, .. } = this.buffer { + write.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } else { + this.io.poll_write(cx, buf) + } } fn poll_write_vectored( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, bufs: &[std::io::IoSlice<'_>], ) -> Poll> { - self.project().reader.poll_write_vectored(cx, bufs) + let this = self.project(); + if let BTBuffer::Buffering { write, .. } = this.buffer { + let len = bufs.iter().map(|b| b.len()).sum(); + write.reserve(len); + for buf in bufs { + write.extend_from_slice(buf); + } + Poll::Ready(Ok(len)) + } else { + this.io.poll_write_vectored(cx, bufs) + } } }