diff --git a/core/Cargo.lock b/core/Cargo.lock index 031ec41ab..f8db0bd7e 100644 --- a/core/Cargo.lock +++ b/core/Cargo.lock @@ -4881,6 +4881,7 @@ dependencies = [ "hmac", "http 1.1.0", "http-body-util", + "hyper-util", "id-pool", "imbl", "imbl-value", @@ -4936,6 +4937,7 @@ dependencies = [ "sha2 0.10.8", "shell-words", "simple-logging", + "socket2", "sqlx", "sscanf", "ssh-key", diff --git a/core/startos/Cargo.toml b/core/startos/Cargo.toml index a8707bf65..750bc4953 100644 --- a/core/startos/Cargo.toml +++ b/core/startos/Cargo.toml @@ -97,6 +97,7 @@ hex = "0.4.3" hmac = "0.12.1" http = "1.0.0" http-body-util = "0.1" +hyper-util = { version = "0.1.5", features = ["tokio", "service"] } id-pool = { version = "0.2.2", default-features = false, features = [ "serde", "u16", @@ -159,6 +160,7 @@ serde_yaml = { package = "serde_yml", version = "0.0.10" } sha2 = "0.10.2" shell-words = "1" simple-logging = "2.0.2" +socket2 = "0.5.7" sqlx = { version = "0.7.2", features = [ "chrono", "runtime-tokio-rustls", diff --git a/core/startos/src/net/utils.rs b/core/startos/src/net/utils.rs index 6de319a5e..9cba8a0cd 100644 --- a/core/startos/src/net/utils.rs +++ b/core/startos/src/net/utils.rs @@ -112,24 +112,6 @@ pub async fn find_eth_iface() -> Result { )) } -#[pin_project::pin_project] -pub struct SingleAccept(Option); -impl SingleAccept { - pub fn new(conn: T) -> Self { - Self(Some(conn)) - } -} -// impl axum_server::accept::Accept for SingleAccept { -// type Conn = T; -// type Error = Infallible; -// fn poll_accept( -// self: std::pin::Pin<&mut Self>, -// _cx: &mut std::task::Context<'_>, -// ) -> std::task::Poll>> { -// std::task::Poll::Ready(self.project().0.take().map(Ok)) -// } -// } - pub struct TcpListeners { listeners: Vec, } diff --git a/core/startos/src/net/vhost.rs b/core/startos/src/net/vhost.rs index b4f5715ae..e6a9d5b21 100644 --- a/core/startos/src/net/vhost.rs +++ b/core/startos/src/net/vhost.rs @@ -1,10 +1,15 @@ use std::collections::BTreeMap; use std::net::{IpAddr, Ipv6Addr, SocketAddr}; +use std::str::FromStr; use std::sync::{Arc, Weak}; use std::time::Duration; +use axum::body::Body; +use axum::extract::Request; +use axum::response::Response; use color_eyre::eyre::eyre; use helpers::NonDetachingJoinHandle; +use http::Uri; use imbl_value::InternedString; use models::ResultExt; use serde::{Deserialize, Serialize}; @@ -20,8 +25,9 @@ use tracing::instrument; use ts_rs::TS; use crate::db::model::Database; +use crate::net::static_server::server_error; use crate::prelude::*; -use crate::util::io::{BackTrackingReader, TimeoutStream}; +use crate::util::io::BackTrackingReader; use crate::util::serde::MaybeUtf8String; // not allowed: <=1024, >=32768, 5355, 5432, 9050, 6010, 9051, 5353 @@ -113,8 +119,16 @@ impl VHostServer { loop { match listener.accept().await { Ok((stream, _)) => { - let stream = - Box::pin(TimeoutStream::new(stream, Duration::from_secs(300))); + if let Err(e) = socket2::SockRef::from(&stream).set_tcp_keepalive( + &socket2::TcpKeepalive::new() + .with_time(Duration::from_secs(900)) + .with_interval(Duration::from_secs(60)) + .with_retries(5), + ) { + tracing::error!("Failed to set tcp keepalive: {e}"); + tracing::debug!("{e:?}"); + } + let mut stream = BackTrackingReader::new(stream); stream.start_buffering(); let mapping = mapping.clone(); @@ -129,38 +143,39 @@ impl VHostServer { { Ok(a) => a, Err(_) => { - // stream.rewind(); - // return hyper::server::Server::builder( - // SingleAccept::new(stream), - // ) - // .serve(make_service_fn(|_| async { - // Ok::<_, Infallible>(service_fn(|req| async move { - // let host = req - // .headers() - // .get(http::header::HOST) - // .and_then(|host| host.to_str().ok()); - // let uri = Uri::from_parts({ - // let mut parts = - // req.uri().to_owned().into_parts(); - // parts.authority = host - // .map(FromStr::from_str) - // .transpose()?; - // parts - // })?; - // Response::builder() - // .status( - // http::StatusCode::TEMPORARY_REDIRECT, - // ) - // .header( - // http::header::LOCATION, - // uri.to_string(), - // ) - // .body(Body::default()) - // })) - // })) - // .await - // .with_kind(crate::ErrorKind::Network); - todo!() + stream.rewind(); + return hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new()) + .serve_connection( + hyper_util::rt::TokioIo::new(stream), + hyper_util::service::TowerToHyperService::new(axum::Router::new().fallback( + axum::routing::method_routing::any(move |req: Request| async move { + match async move { + let host = req + .headers() + .get(http::header::HOST) + .and_then(|host| host.to_str().ok()); + let uri = Uri::from_parts({ + let mut parts = req.uri().to_owned().into_parts(); + parts.authority = host.map(FromStr::from_str).transpose()?; + parts + })?; + Response::builder() + .status(http::StatusCode::TEMPORARY_REDIRECT) + .header(http::header::LOCATION, uri.to_string()) + .body(Body::default()) + }.await { + Ok(a) => a, + Err(e) => { + tracing::warn!("Error redirecting http request on ssl port: {e}"); + tracing::error!("{e:?}"); + server_error(Error::new(e, ErrorKind::Network)) + } + } + }), + )), + ) + .await + .map_err(|e| Error::new(color_eyre::eyre::Report::msg(e), ErrorKind::Network)); } }; let target_name = diff --git a/core/startos/src/s9pk/v2/manifest.rs b/core/startos/src/s9pk/v2/manifest.rs index 77e48c126..9ae8524fa 100644 --- a/core/startos/src/s9pk/v2/manifest.rs +++ b/core/startos/src/s9pk/v2/manifest.rs @@ -146,7 +146,7 @@ impl Manifest { #[ts(export)] pub struct HardwareRequirements { #[serde(default)] - #[ts(type = "{ [key: string]: string }")] + #[ts(type = "{ [key: string]: string }")] // TODO more specific key pub device: BTreeMap, #[ts(type = "number | null")] pub ram: Option, diff --git a/core/startos/src/util/io.rs b/core/startos/src/util/io.rs index 9a6bab64b..b5748d1d9 100644 --- a/core/startos/src/util/io.rs +++ b/core/startos/src/util/io.rs @@ -1,4 +1,4 @@ -use std::collections::{BTreeSet, VecDeque}; +use std::collections::VecDeque; use std::future::Future; use std::io::Cursor; use std::os::unix::prelude::MetadataExt; @@ -706,16 +706,16 @@ impl AsyncRead for TimeoutStream { buf: &mut tokio::io::ReadBuf<'_>, ) -> std::task::Poll> { let mut this = self.project(); - if let std::task::Poll::Ready(_) = this.sleep.as_mut().poll(cx) { + let timeout = this.sleep.as_mut().poll(cx); + let res = this.stream.poll_read(cx, buf); + if res.is_ready() { + this.sleep.reset(Instant::now() + *this.timeout); + } else if timeout.is_ready() { return std::task::Poll::Ready(Err(std::io::Error::new( std::io::ErrorKind::TimedOut, "timed out", ))); } - let res = this.stream.poll_read(cx, buf); - if res.is_ready() { - this.sleep.reset(Instant::now() + *this.timeout); - } res } } @@ -725,10 +725,16 @@ impl AsyncWrite for TimeoutStream { cx: &mut std::task::Context<'_>, buf: &[u8], ) -> std::task::Poll> { - let this = self.project(); + let mut this = self.project(); + let timeout = this.sleep.as_mut().poll(cx); let res = this.stream.poll_write(cx, buf); if res.is_ready() { this.sleep.reset(Instant::now() + *this.timeout); + } else if timeout.is_ready() { + return std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "timed out", + ))); } res } @@ -736,10 +742,16 @@ impl AsyncWrite for TimeoutStream { self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - let this = self.project(); + let mut this = self.project(); + let timeout = this.sleep.as_mut().poll(cx); let res = this.stream.poll_flush(cx); if res.is_ready() { this.sleep.reset(Instant::now() + *this.timeout); + } else if timeout.is_ready() { + return std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "timed out", + ))); } res } @@ -747,10 +759,16 @@ impl AsyncWrite for TimeoutStream { self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - let this = self.project(); + let mut this = self.project(); + let timeout = this.sleep.as_mut().poll(cx); let res = this.stream.poll_shutdown(cx); if res.is_ready() { this.sleep.reset(Instant::now() + *this.timeout); + } else if timeout.is_ready() { + return std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "timed out", + ))); } res } diff --git a/patch-db b/patch-db index c537a07ea..99076d349 160000 --- a/patch-db +++ b/patch-db @@ -1 +1 @@ -Subproject commit c537a07ea937e69b66841d903c70fd75623e5457 +Subproject commit 99076d349c6768000483ea8d47216d273586552e