Files
start-os/core/startos/src/net/web_server.rs
2025-12-18 04:22:56 -07:00

565 lines
17 KiB
Rust

use core::fmt;
use std::any::Any;
use std::collections::BTreeMap;
use std::future::Future;
use std::net::SocketAddr;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Poll, ready};
use std::time::Duration;
use axum::Router;
use futures::future::Either;
use futures::{FutureExt, TryFutureExt};
use http::Extensions;
use hyper_util::rt::{TokioIo, TokioTimer};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use visit_rs::{Visit, VisitFields, Visitor};
use crate::net::static_server::{UiContext, ui_router};
use crate::prelude::*;
use crate::util::actor::background::BackgroundJobQueue;
use crate::util::future::NonDetachingJoinHandle;
use crate::util::io::ReadWriter;
use crate::util::sync::{SyncRwLock, Watch};
pub type AcceptStream = Pin<Box<dyn ReadWriter + Send + 'static>>;
pub trait MetadataVisitor: Visitor<Result = ()> {
fn visit<M: Clone + Send + Sync + 'static>(&mut self, metadata: &M) -> Self::Result;
}
pub struct ExtensionVisitor<'a>(&'a mut Extensions);
impl<'a> Visitor for ExtensionVisitor<'a> {
type Result = ();
}
impl<'a> MetadataVisitor for ExtensionVisitor<'a> {
fn visit<M: Clone + Send + Sync + 'static>(&mut self, metadata: &M) -> Self::Result {
self.0.insert(metadata.clone());
}
}
impl<'a> Visit<ExtensionVisitor<'a>>
for Box<dyn for<'x> Visit<ExtensionVisitor<'x>> + Send + Sync + 'static>
{
fn visit(
&self,
visitor: &mut ExtensionVisitor<'a>,
) -> <ExtensionVisitor<'a> as Visitor>::Result {
(&**self).visit(visitor)
}
}
pub struct ExtractVisitor<T>(Option<T>);
impl<T> Visitor for ExtractVisitor<T> {
type Result = ();
}
impl<T: Clone + Send + Sync + 'static> MetadataVisitor for ExtractVisitor<T> {
fn visit<M: Clone + Send + Sync + 'static>(&mut self, metadata: &M) -> Self::Result {
if let Some(matching) = (metadata as &dyn Any).downcast_ref::<T>() {
self.0 = Some(matching.clone());
}
}
}
pub fn extract<
T: Clone + Send + Sync + 'static,
M: Visit<ExtractVisitor<T>> + Clone + Send + Sync + 'static,
>(
metadata: &M,
) -> Option<T> {
let mut visitor = ExtractVisitor(None);
metadata.visit(&mut visitor);
visitor.0
}
#[derive(Clone, Copy, Debug)]
pub struct TcpMetadata {
pub peer_addr: SocketAddr,
pub local_addr: SocketAddr,
}
impl<V: MetadataVisitor> Visit<V> for TcpMetadata {
fn visit(&self, visitor: &mut V) -> <V as visit_rs::Visitor>::Result {
visitor.visit(self)
}
}
pub trait Accept {
type Metadata: fmt::Debug;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>>;
fn into_dyn(self) -> DynAccept
where
Self: Sized + Send + Sync + 'static,
for<'a> Self::Metadata: Visit<ExtensionVisitor<'a>> + Send + Sync + 'static,
{
DynAccept::new(self)
}
}
impl Accept for TcpListener {
type Metadata = TcpMetadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
if let Poll::Ready((stream, peer_addr)) = TcpListener::poll_accept(self, cx)? {
if let Err(e) = socket2::SockRef::from(&stream).set_tcp_keepalive(
&socket2::TcpKeepalive::new()
.with_time(Duration::from_secs(900))
.with_interval(Duration::from_secs(60))
.with_retries(5),
) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
return Poll::Ready(Ok((
TcpMetadata {
local_addr: self.local_addr()?,
peer_addr,
},
Box::pin(stream),
)));
}
Poll::Pending
}
}
impl<A> Accept for Vec<A>
where
A: Accept,
{
type Metadata = A::Metadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
for listener in self {
if let Poll::Ready(accepted) = listener.poll_accept(cx)? {
return Poll::Ready(Ok(accepted));
}
}
Poll::Pending
}
}
#[derive(Debug, Clone, VisitFields)]
pub struct MapListenerMetadata<K, M> {
pub inner: M,
pub key: K,
}
impl<K, M, V> Visit<V> for MapListenerMetadata<K, M>
where
V: MetadataVisitor,
K: Visit<V>,
M: Visit<V>,
{
fn visit(&self, visitor: &mut V) -> <V as Visitor>::Result {
self.visit_fields(visitor).collect()
}
}
impl<K, A> Accept for BTreeMap<K, A>
where
K: Clone + fmt::Debug,
A: Accept,
{
type Metadata = MapListenerMetadata<K, A::Metadata>;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
for (key, listener) in self {
if let Poll::Ready((metadata, stream)) = listener.poll_accept(cx)? {
return Poll::Ready(Ok((
MapListenerMetadata {
inner: metadata,
key: key.clone(),
},
stream,
)));
}
}
Poll::Pending
}
}
impl<A, B> Accept for Either<A, B>
where
A: Accept,
B: Accept<Metadata = A::Metadata>,
{
type Metadata = A::Metadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
match self {
Either::Left(a) => a.poll_accept(cx),
Either::Right(b) => b.poll_accept(cx),
}
}
}
impl<A: Accept> Accept for Option<A> {
type Metadata = A::Metadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
match self {
None => Poll::Pending,
Some(a) => a.poll_accept(cx),
}
}
}
trait DynAcceptT: Send + Sync {
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(DynMetadata, AcceptStream), Error>>;
}
impl<A> DynAcceptT for A
where
A: Accept + Send + Sync,
<A as Accept>::Metadata: DynMetadataT + 'static,
{
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(DynMetadata, AcceptStream), Error>> {
let (metadata, stream) = ready!(Accept::poll_accept(self, cx)?);
Poll::Ready(Ok((DynMetadata(Box::new(metadata)), stream)))
}
}
pub struct DynAccept(Box<dyn DynAcceptT>);
trait DynMetadataT: for<'a> Visit<ExtensionVisitor<'a>> + fmt::Debug + Send + Sync {}
impl<T> DynMetadataT for T where for<'a> T: Visit<ExtensionVisitor<'a>> + fmt::Debug + Send + Sync {}
#[derive(Debug)]
pub struct DynMetadata(Box<dyn DynMetadataT>);
impl<'a> Visit<ExtensionVisitor<'a>> for DynMetadata {
fn visit(
&self,
visitor: &mut ExtensionVisitor<'a>,
) -> <ExtensionVisitor<'a> as Visitor>::Result {
self.0.visit(visitor)
}
}
impl Accept for DynAccept {
type Metadata = DynMetadata;
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
DynAcceptT::poll_accept(&mut *self.0, cx)
}
fn into_dyn(self) -> DynAccept
where
Self: Sized,
for<'a> Self::Metadata: Visit<ExtensionVisitor<'a>> + Send + Sync + 'static,
{
self
}
}
impl DynAccept {
pub fn new<A>(accept: A) -> Self
where
A: Accept + Send + Sync + 'static,
for<'a> <A as Accept>::Metadata: Visit<ExtensionVisitor<'a>> + Send + Sync + 'static,
{
Self(Box::new(accept))
}
}
#[pin_project::pin_project]
pub struct Acceptor<A: Accept> {
acceptor: Watch<A>,
}
impl<A: Accept + Send + Sync + 'static> Acceptor<A> {
pub fn new(acceptor: A) -> Self {
Self {
acceptor: Watch::new(acceptor),
}
}
fn poll_changed(&mut self, cx: &mut std::task::Context<'_>) -> Poll<()> {
self.acceptor.poll_changed(cx)
}
fn poll_accept(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(A::Metadata, AcceptStream), Error>> {
while self.poll_changed(cx).is_ready() {}
self.acceptor.peek_mut(|a| a.poll_accept(cx))
}
async fn accept(&mut self) -> Result<(A::Metadata, AcceptStream), Error> {
std::future::poll_fn(|cx| self.poll_accept(cx)).await
}
}
impl Acceptor<Vec<TcpListener>> {
pub async fn bind(listen: impl IntoIterator<Item = SocketAddr>) -> Result<Self, Error> {
Ok(Self::new(
futures::future::try_join_all(listen.into_iter().map(TcpListener::bind)).await?,
))
}
}
impl Acceptor<Vec<DynAccept>> {
pub async fn bind_dyn(listen: impl IntoIterator<Item = SocketAddr>) -> Result<Self, Error> {
Ok(Self::new(
futures::future::try_join_all(
listen
.into_iter()
.map(TcpListener::bind)
.map(|f| f.map_ok(DynAccept::new)),
)
.await?,
))
}
}
impl<K> Acceptor<BTreeMap<K, TcpListener>>
where
K: Ord + Clone + fmt::Debug + Send + Sync + 'static,
{
pub async fn bind_map(
listen: impl IntoIterator<Item = (K, SocketAddr)>,
) -> Result<Self, Error> {
Ok(Self::new(
futures::future::try_join_all(listen.into_iter().map(|(key, addr)| async move {
Ok::<_, Error>((
key,
TcpListener::bind(addr)
.await
.with_kind(ErrorKind::Network)?,
))
}))
.await?
.into_iter()
.collect(),
))
}
}
impl<K> Acceptor<BTreeMap<K, DynAccept>>
where
K: Ord + Clone + fmt::Debug + Send + Sync + 'static,
{
pub async fn bind_map_dyn(
listen: impl IntoIterator<Item = (K, SocketAddr)>,
) -> Result<Self, Error> {
Ok(Self::new(
futures::future::try_join_all(listen.into_iter().map(|(key, addr)| async move {
Ok::<_, Error>((
key,
TcpListener::bind(addr)
.await
.with_kind(ErrorKind::Network)?,
))
}))
.await?
.into_iter()
.map(|(key, listener)| (key, listener.into_dyn()))
.collect(),
))
}
}
pub struct WebServerAcceptorSetter<A: Accept> {
acceptor: Watch<A>,
}
impl<A, B> WebServerAcceptorSetter<Option<Either<A, B>>>
where
A: Accept,
B: Accept<Metadata = A::Metadata>,
{
pub fn try_upgrade<F: FnOnce(A) -> Result<B, Error>>(&self, f: F) -> Result<(), Error> {
let mut res = Ok(());
self.acceptor.send_modify(|a| {
*a = match a.take() {
Some(Either::Left(a)) => match f(a) {
Ok(b) => Some(Either::Right(b)),
Err(e) => {
res = Err(e);
None
}
},
x => x,
}
});
res
}
}
impl<A: Accept> Deref for WebServerAcceptorSetter<A> {
type Target = Watch<A>;
fn deref(&self) -> &Self::Target {
&self.acceptor
}
}
pub struct WebServer<A: Accept> {
shutdown: oneshot::Sender<()>,
router: Watch<Router>,
acceptor: Watch<A>,
thread: NonDetachingJoinHandle<()>,
}
impl<A> WebServer<A>
where
A: Accept + Send + Sync + 'static,
for<'a> A::Metadata: Visit<ExtensionVisitor<'a>> + Send + Sync + 'static,
{
pub fn acceptor_setter(&self) -> WebServerAcceptorSetter<A> {
WebServerAcceptorSetter {
acceptor: self.acceptor.clone(),
}
}
pub fn new(mut acceptor: Acceptor<A>, router: Router) -> Self {
let acceptor_send = acceptor.acceptor.clone();
let router = Watch::new(router);
let service = router.clone_unseen();
let (shutdown, shutdown_recv) = oneshot::channel();
let thread = NonDetachingJoinHandle::from(tokio::spawn(async move {
#[derive(Clone)]
struct QueueRunner {
queue: Arc<SyncRwLock<Option<BackgroundJobQueue>>>,
}
impl<Fut> hyper::rt::Executor<Fut> for QueueRunner
where
Fut: Future + Send + 'static,
{
fn execute(&self, fut: Fut) {
self.queue.peek(|q| {
if let Some(q) = q {
q.add_job(fut);
} else {
tracing::warn!("job queued after shutdown");
}
})
}
}
struct SwappableRouter<M> {
router: Watch<Router>,
metadata: M,
}
impl<M: for<'a> Visit<ExtensionVisitor<'a>> + Send + Sync + 'static>
hyper::service::Service<hyper::Request<hyper::body::Incoming>>
for SwappableRouter<M>
{
type Response = <Router as tower_service::Service<
hyper::Request<hyper::body::Incoming>,
>>::Response;
type Error = <Router as tower_service::Service<
hyper::Request<hyper::body::Incoming>,
>>::Error;
type Future = <Router as tower_service::Service<
hyper::Request<hyper::body::Incoming>,
>>::Future;
fn call(&self, mut req: hyper::Request<hyper::body::Incoming>) -> Self::Future {
use tower_service::Service;
self.metadata
.visit(&mut ExtensionVisitor(req.extensions_mut()));
self.router.read().call(req)
}
}
let queue_cell = Arc::new(SyncRwLock::new(None));
let graceful = hyper_util::server::graceful::GracefulShutdown::new();
let mut server = hyper_util::server::conn::auto::Builder::new(QueueRunner {
queue: queue_cell.clone(),
});
server
.http1()
.timer(TokioTimer::new())
.title_case_headers(true)
.preserve_header_case(true)
.http2()
.timer(TokioTimer::new())
.enable_connect_protocol()
.keep_alive_interval(Duration::from_secs(25))
.keep_alive_timeout(Duration::from_secs(300));
let (queue, mut runner) = BackgroundJobQueue::new();
queue_cell.replace(Some(queue.clone()));
let handler = async {
loop {
let mut err = None;
for _ in 0..5 {
if let Err(e) = async {
let (metadata, stream) = acceptor.accept().await?;
queue.add_job(
graceful.watch(
server
.serve_connection_with_upgrades(
TokioIo::new(stream),
SwappableRouter {
router: service.clone(),
metadata,
},
)
.into_owned(),
),
);
Ok::<_, Error>(())
}
.await
{
err = Some(e);
tokio::time::sleep(Duration::from_millis(100)).await;
} else {
break;
}
}
if let Some(e) = err {
tracing::error!("Error accepting HTTP connection: {e}");
tracing::debug!("{e:?}");
}
}
}
.boxed();
tokio::select! {
_ = shutdown_recv => (),
_ = handler => (),
_ = &mut runner => (),
}
drop(queue);
drop(queue_cell.replace(None));
if !runner.is_empty() {
tokio::time::timeout(Duration::from_secs(60), runner)
.await
.log_err();
}
}));
Self {
shutdown,
router,
thread,
acceptor: acceptor_send,
}
}
pub async fn shutdown(self) {
self.shutdown.send(()).unwrap_or_default();
self.thread.await.unwrap()
}
pub fn serve_router(&mut self, router: Router) {
self.router.send(router)
}
pub fn serve_ui_for<C: UiContext>(&mut self, ctx: C) {
self.serve_router(ui_router(ctx))
}
}