fix https redirect

This commit is contained in:
Aiden McClelland
2024-07-25 14:34:30 -06:00
parent a9373d9779
commit 419e3f7f2b
3 changed files with 129 additions and 70 deletions

View File

@@ -98,7 +98,12 @@ hex = "0.4.3"
hmac = "0.12.1" hmac = "0.12.1"
http = "1.0.0" http = "1.0.0"
http-body-util = "0.1" http-body-util = "0.1"
hyper-util = { version = "0.1.5", features = ["tokio", "service"] } hyper-util = { version = "0.1.5", features = [
"tokio",
"service",
"http1",
"http2",
] }
id-pool = { version = "0.2.2", default-features = false, features = [ id-pool = { version = "0.2.2", default-features = false, features = [
"serde", "serde",
"u16", "u16",

View File

@@ -13,6 +13,7 @@ use http::Uri;
use imbl_value::InternedString; use imbl_value::InternedString;
use models::ResultExt; use models::ResultExt;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::io::AsyncWriteExt;
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{Mutex, RwLock}; use tokio::sync::{Mutex, RwLock};
use tokio_rustls::rustls::pki_types::{ use tokio_rustls::rustls::pki_types::{
@@ -27,7 +28,7 @@ use ts_rs::TS;
use crate::db::model::Database; use crate::db::model::Database;
use crate::net::static_server::server_error; use crate::net::static_server::server_error;
use crate::prelude::*; use crate::prelude::*;
use crate::util::io::BackTrackingReader; use crate::util::io::BackTrackingIO;
use crate::util::serde::MaybeUtf8String; use crate::util::serde::MaybeUtf8String;
// not allowed: <=1024, >=32768, 5355, 5432, 9050, 6010, 9051, 5353 // not allowed: <=1024, >=32768, 5355, 5432, 9050, 6010, 9051, 5353
@@ -129,8 +130,7 @@ impl VHostServer {
tracing::debug!("{e:?}"); tracing::debug!("{e:?}");
} }
let mut stream = BackTrackingReader::new(stream); let mut stream = BackTrackingIO::new(stream);
stream.start_buffering();
let mapping = mapping.clone(); let mapping = mapping.clone();
let db = db.clone(); let db = db.clone();
tokio::spawn(async move { tokio::spawn(async move {
@@ -156,6 +156,7 @@ impl VHostServer {
.and_then(|host| host.to_str().ok()); .and_then(|host| host.to_str().ok());
let uri = Uri::from_parts({ let uri = Uri::from_parts({
let mut parts = req.uri().to_owned().into_parts(); let mut parts = req.uri().to_owned().into_parts();
parts.scheme = Some("https".parse()?);
parts.authority = host.map(FromStr::from_str).transpose()?; parts.authority = host.map(FromStr::from_str).transpose()?;
parts parts
})?; })?;
@@ -313,8 +314,12 @@ impl VHostServer {
) )
.await .await
.with_kind(crate::ErrorKind::OpenSsl)?; .with_kind(crate::ErrorKind::OpenSsl)?;
let mut accept = mid.into_stream(Arc::new(cfg));
let io = accept.get_mut().unwrap();
let buffered = io.stop_buffering();
io.write_all(&buffered).await?;
let mut tls_stream = let mut tls_stream =
match mid.into_stream(Arc::new(cfg)).await { match accept.await {
Ok(a) => a, Ok(a) => a,
Err(e) => { Err(e) => {
tracing::trace!( "VHostController: failed to accept TLS connection on port {port}: {e}"); tracing::trace!( "VHostController: failed to accept TLS connection on port {port}: {e}");
@@ -322,7 +327,6 @@ impl VHostServer {
return Ok(()) return Ok(())
} }
}; };
tls_stream.get_mut().0.stop_buffering();
tokio::io::copy_bidirectional( tokio::io::copy_bidirectional(
&mut tls_stream, &mut tls_stream,
&mut target_stream, &mut target_stream,
@@ -335,8 +339,12 @@ impl VHostServer {
{ {
cfg.alpn_protocols.push(proto.into()); cfg.alpn_protocols.push(proto.into());
} }
let mut accept = mid.into_stream(Arc::new(cfg));
let io = accept.get_mut().unwrap();
let buffered = io.stop_buffering();
io.write_all(&buffered).await?;
let mut tls_stream = let mut tls_stream =
match mid.into_stream(Arc::new(cfg)).await { match accept.await {
Ok(a) => a, Ok(a) => a,
Err(e) => { Err(e) => {
tracing::trace!( "VHostController: failed to accept TLS connection on port {port}: {e}"); tracing::trace!( "VHostController: failed to accept TLS connection on port {port}: {e}");
@@ -344,7 +352,6 @@ impl VHostServer {
return Ok(()) return Ok(())
} }
}; };
tls_stream.get_mut().0.stop_buffering();
tokio::io::copy_bidirectional( tokio::io::copy_bidirectional(
&mut tls_stream, &mut tls_stream,
&mut tcp_stream, &mut tcp_stream,
@@ -353,8 +360,12 @@ impl VHostServer {
} }
Err(AlpnInfo::Specified(alpn)) => { Err(AlpnInfo::Specified(alpn)) => {
cfg.alpn_protocols = alpn.into_iter().map(|a| a.0).collect(); cfg.alpn_protocols = alpn.into_iter().map(|a| a.0).collect();
let mut accept = mid.into_stream(Arc::new(cfg));
let io = accept.get_mut().unwrap();
let buffered = io.stop_buffering();
io.write_all(&buffered).await?;
let mut tls_stream = let mut tls_stream =
match mid.into_stream(Arc::new(cfg)).await { match accept.await {
Ok(a) => a, Ok(a) => a,
Err(e) => { Err(e) => {
tracing::trace!( "VHostController: failed to accept TLS connection on port {port}: {e}"); tracing::trace!( "VHostController: failed to accept TLS connection on port {port}: {e}");
@@ -362,7 +373,6 @@ impl VHostServer {
return Ok(()) return Ok(())
} }
}; };
tls_stream.get_mut().0.stop_buffering();
tokio::io::copy_bidirectional( tokio::io::copy_bidirectional(
&mut tls_stream, &mut tls_stream,
&mut tcp_stream, &mut tcp_stream,

View File

@@ -411,107 +411,151 @@ impl<T: AsRef<[u8]>> CursorExt for Cursor<T> {
} }
} }
#[pin_project::pin_project]
#[derive(Debug)] #[derive(Debug)]
pub struct BackTrackingReader<T> { enum BTBuffer {
#[pin] NotBuffering,
reader: T, Buffering { read: Vec<u8>, write: Vec<u8> },
buffer: Cursor<Vec<u8>>, Rewound { read: Cursor<Vec<u8>> },
buffering: bool,
} }
impl<T> BackTrackingReader<T> { impl Default for BTBuffer {
pub fn new(reader: T) -> Self { fn default() -> Self {
Self { BTBuffer::NotBuffering
reader,
buffer: Cursor::new(Vec::new()),
buffering: false,
}
}
pub fn start_buffering(&mut self) {
self.buffer.set_position(0);
self.buffer.get_mut().truncate(0);
self.buffering = true;
}
pub fn stop_buffering(&mut self) {
self.buffer.set_position(0);
self.buffer.get_mut().truncate(0);
self.buffering = false;
}
pub fn rewind(&mut self) {
self.buffering = false;
}
pub fn unwrap(self) -> T {
self.reader
} }
} }
impl<T: AsyncRead> AsyncRead for BackTrackingReader<T> { #[pin_project::pin_project]
#[derive(Debug)]
pub struct BackTrackingIO<T> {
#[pin]
io: T,
buffer: BTBuffer,
}
impl<T> BackTrackingIO<T> {
pub fn new(io: T) -> Self {
Self {
io,
buffer: BTBuffer::Buffering {
read: Vec::new(),
write: Vec::new(),
},
}
}
#[must_use]
pub fn stop_buffering(&mut self) -> Vec<u8> {
match std::mem::take(&mut self.buffer) {
BTBuffer::Buffering { write, .. } => write,
BTBuffer::NotBuffering => Vec::new(),
BTBuffer::Rewound { read } => {
self.buffer = BTBuffer::Rewound { read };
Vec::new()
}
}
}
pub fn rewind(&mut self) -> Vec<u8> {
match std::mem::take(&mut self.buffer) {
BTBuffer::Buffering { read, write } => {
self.buffer = BTBuffer::Rewound {
read: Cursor::new(read),
};
write
}
BTBuffer::NotBuffering => Vec::new(),
BTBuffer::Rewound { read } => {
self.buffer = BTBuffer::Rewound { read };
Vec::new()
}
}
}
pub fn unwrap(self) -> T {
self.io
}
}
impl<T: AsyncRead> AsyncRead for BackTrackingIO<T> {
fn poll_read( fn poll_read(
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
buf: &mut ReadBuf<'_>, buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> { ) -> Poll<std::io::Result<()>> {
let this = self.project(); let this = self.project();
if *this.buffering { match this.buffer {
let filled = buf.filled().len(); BTBuffer::Buffering { read, .. } => {
let res = this.reader.poll_read(cx, buf); let filled = buf.filled().len();
this.buffer let res = this.io.poll_read(cx, buf);
.get_mut() read.extend_from_slice(&buf.filled()[filled..]);
.extend_from_slice(&buf.filled()[filled..]); res
res
} else {
let mut ready = false;
if (this.buffer.position() as usize) < this.buffer.get_ref().len() {
this.buffer.pure_read(buf);
ready = true;
} }
if buf.remaining() > 0 { BTBuffer::NotBuffering => this.io.poll_read(cx, buf),
match this.reader.poll_read(cx, buf) { BTBuffer::Rewound { read } => {
Poll::Pending => { let mut ready = false;
if ready { if (read.position() as usize) < read.get_ref().len() {
Poll::Ready(Ok(())) read.pure_read(buf);
} else { ready = true;
Poll::Pending }
} if buf.remaining() > 0 {
} match this.io.poll_read(cx, buf) {
a => a, Poll::Pending => {
if ready {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
a => a,
}
} else {
Poll::Ready(Ok(()))
} }
} else {
Poll::Ready(Ok(()))
} }
} }
} }
} }
impl<T: AsyncWrite> AsyncWrite for BackTrackingReader<T> { impl<T: AsyncWrite> AsyncWrite for BackTrackingIO<T> {
fn is_write_vectored(&self) -> bool { fn is_write_vectored(&self) -> bool {
self.reader.is_write_vectored() self.io.is_write_vectored()
} }
fn poll_flush( fn poll_flush(
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> { ) -> Poll<Result<(), std::io::Error>> {
self.project().reader.poll_flush(cx) self.project().io.poll_flush(cx)
} }
fn poll_shutdown( fn poll_shutdown(
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> { ) -> Poll<Result<(), std::io::Error>> {
self.project().reader.poll_shutdown(cx) self.project().io.poll_shutdown(cx)
} }
fn poll_write( fn poll_write(
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
buf: &[u8], buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> { ) -> Poll<Result<usize, std::io::Error>> {
self.project().reader.poll_write(cx, buf) let this = self.project();
if let BTBuffer::Buffering { write, .. } = this.buffer {
write.extend_from_slice(buf);
Poll::Ready(Ok(buf.len()))
} else {
this.io.poll_write(cx, buf)
}
} }
fn poll_write_vectored( fn poll_write_vectored(
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
bufs: &[std::io::IoSlice<'_>], bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> { ) -> Poll<Result<usize, std::io::Error>> {
self.project().reader.poll_write_vectored(cx, bufs) let this = self.project();
if let BTBuffer::Buffering { write, .. } = this.buffer {
let len = bufs.iter().map(|b| b.len()).sum();
write.reserve(len);
for buf in bufs {
write.extend_from_slice(buf);
}
Poll::Ready(Ok(len))
} else {
this.io.poll_write_vectored(cx, bufs)
}
} }
} }