// HTTP web server plumbing: pluggable connection acceptors (`Accept`), per-connection
// metadata that is copied into each request's `http::Extensions`, and a `WebServer`
// whose router and acceptor can both be replaced while it is running.
use core::fmt;
use std::any::Any;
use std::collections::BTreeMap;
use std::future::Future;
use std::net::SocketAddr;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Poll, ready};
use std::time::Duration;

use axum::Router;
use futures::future::Either;
use futures::{FutureExt, TryFutureExt};
use http::Extensions;
use hyper_util::rt::{TokioIo, TokioTimer};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use visit_rs::{Visit, VisitFields, Visitor};

use crate::net::static_server::{UiContext, ui_router};
use crate::prelude::*;
use crate::util::actor::background::BackgroundJobQueue;
use crate::util::future::NonDetachingJoinHandle;
use crate::util::io::ReadWriter;
use crate::util::sync::{SyncRwLock, Watch};

/// Boxed, pinned bidirectional byte stream handed out for every accepted connection.
pub type AcceptStream = Pin>;

/// A [`Visitor`] that can be handed individual pieces of connection metadata.
pub trait MetadataVisitor: Visitor {
    /// Called once for each metadata value reachable from the visited object.
    fn visit(&mut self, metadata: &M) -> Self::Result;
}

/// Visitor that copies every visited metadata value into a request's
/// [`Extensions`] map (via `Extensions::insert`), so handlers can read it back.
pub struct ExtensionVisitor<'a>(&'a mut Extensions);
impl<'a> Visitor for ExtensionVisitor<'a> {
    type Result = ();
}
impl<'a> MetadataVisitor for ExtensionVisitor<'a> {
    fn visit(&mut self, metadata: &M) -> Self::Result {
        self.0.insert(metadata.clone());
    }
}
/// Lets a boxed, type-erased visitable value be visited like the value it wraps.
impl<'a> Visit> for Box Visit> + Send + Sync + 'static> {
    fn visit(
        &self,
        visitor: &mut ExtensionVisitor<'a>,
    ) -> as Visitor>::Result {
        (&**self).visit(visitor)
    }
}

/// Visitor that remembers a clone of the visited metadata value whose concrete
/// type matches the target type (checked with `Any::downcast_ref`).
pub struct ExtractVisitor(Option);
impl Visitor for ExtractVisitor {
    type Result = ();
}
impl MetadataVisitor for ExtractVisitor {
    fn visit(&mut self, metadata: &M) -> Self::Result {
        // Each match overwrites the previous one, so the last matching value wins.
        if let Some(matching) = (metadata as &dyn Any).downcast_ref::() {
            self.0 = Some(matching.clone());
        }
    }
}

/// Returns a clone of the metadata value of type `T` found while visiting
/// `metadata`, or `None` if no visited value has that type. If several visited
/// values have type `T`, the last one visited is returned.
pub fn extract<
    T: Clone + Send + Sync + 'static,
    M: Visit> + Clone + Send + Sync + 'static,
>(
    metadata: &M,
) -> Option {
    let mut visitor = ExtractVisitor(None);
    metadata.visit(&mut visitor);
    visitor.0
}

/// Addresses of an accepted TCP connection.
#[derive(Clone, Copy, Debug)]
pub struct TcpMetadata {
    /// Remote address reported by `accept`.
    pub peer_addr: SocketAddr,
    /// Address reported by the *listener's* `local_addr()` (see the note in
    /// `impl Accept for TcpListener`).
    pub local_addr: SocketAddr,
}
impl Visit for TcpMetadata {
    fn visit(&self, visitor: &mut V) -> ::Result {
        visitor.visit(self)
    }
}

/// A source of incoming connections, each paired with metadata describing it.
pub trait Accept {
    type Metadata: fmt::Debug;

    /// Polls for the next connection. Returns `Poll::Pending` when none is ready.
    fn poll_accept(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll>;

    /// Type-erases this acceptor into a [`DynAccept`].
    fn into_dyn(self) -> DynAccept
    where
        Self: Sized + Send + Sync + 'static,
        for<'a> Self::Metadata: Visit> + Send + Sync + 'static,
    {
        DynAccept::new(self)
    }
}

impl Accept for TcpListener {
    type Metadata = TcpMetadata;
    fn poll_accept(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll> {
        if let Poll::Ready((stream, peer_addr)) = TcpListener::poll_accept(self, cx)? {
            // Enable TCP keepalive (900 s idle, 60 s probe interval, 5 probes).
            // Failure is logged but the connection is still accepted.
            if let Err(e) = socket2::SockRef::from(&stream).set_tcp_keepalive(
                &socket2::TcpKeepalive::new()
                    .with_time(Duration::from_secs(900))
                    .with_interval(Duration::from_secs(60))
                    .with_retries(5),
            ) {
                tracing::error!("Failed to set tcp keepalive: {e}");
                tracing::debug!("{e:?}");
            }
            return Poll::Ready(Ok((
                TcpMetadata {
                    // NOTE(review): this is the listener's bound address, which is the
                    // same for every connection (e.g. unspecified if bound to 0.0.0.0);
                    // `stream.local_addr()` would give the per-connection local address.
                    // Confirm which one consumers expect before changing.
                    local_addr: self.local_addr()?,
                    peer_addr,
                },
                Box::pin(stream),
            )));
        }
        Poll::Pending
    }
}

/// Accepts from whichever listener in the list is ready first.
impl Accept for Vec
where
    A: Accept,
{
    type Metadata = A::Metadata;
    fn poll_accept(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll> {
        // Listeners are polled in order on every call, so an earlier listener that
        // is continuously ready is always served before later ones.
        for listener in self {
            if let Poll::Ready(accepted) = listener.poll_accept(cx)? {
                return Poll::Ready(Ok(accepted));
            }
        }
        Poll::Pending
    }
}

/// Metadata from a keyed listener: the inner listener's metadata plus the key of
/// the listener that accepted the connection.
#[derive(Debug, Clone, VisitFields)]
pub struct MapListenerMetadata {
    pub inner: M,
    pub key: K,
}

impl Visit for MapListenerMetadata
where
    V: MetadataVisitor,
    K: Visit,
    M: Visit,
{
    /// Visits each field (via the `VisitFields` derive) and collects the results.
    fn visit(&self, visitor: &mut V) -> ::Result {
        self.visit_fields(visitor).collect()
    }
}

/// Accepts from a map of listeners and tags each connection with its listener's key.
impl Accept for BTreeMap
where
    K: Clone + fmt::Debug,
    A: Accept,
{
    type Metadata = MapListenerMetadata;
    fn poll_accept(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll> {
        // Polled in key order; lower keys take priority when several are ready.
        for (key, listener) in self {
            if let Poll::Ready((metadata, stream)) = listener.poll_accept(cx)? {
                return Poll::Ready(Ok((
                    MapListenerMetadata {
                        inner: metadata,
                        key: key.clone(),
                    },
                    stream,
                )));
            }
        }
        Poll::Pending
    }
}

/// Delegates to whichever side is present; both sides share one metadata type.
impl Accept for Either
where
    A: Accept,
    B: Accept,
{
    type Metadata = A::Metadata;
    fn poll_accept(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll> {
        match self {
            Either::Left(a) => a.poll_accept(cx),
            Either::Right(b) => b.poll_accept(cx),
        }
    }
}

/// `None` never yields a connection. No waker is registered in that case; callers
/// presumably re-poll when the value is replaced (as `Acceptor::poll_accept` does
/// through its change notification).
impl Accept for Option {
    type Metadata = A::Metadata;
    fn poll_accept(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll> {
        match self {
            None => Poll::Pending,
            Some(a) => a.poll_accept(cx),
        }
    }
}

/// Object-safe counterpart of [`Accept`] whose metadata is boxed as [`DynMetadata`].
trait DynAcceptT: Send + Sync {
    fn poll_accept(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll>;
}
impl DynAcceptT for A
where
    A: Accept + Send + Sync,
    ::Metadata: DynMetadataT + 'static,
{
    fn poll_accept(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll> {
        let (metadata, stream) = ready!(Accept::poll_accept(self, cx)?);
        Poll::Ready(Ok((DynMetadata(Box::new(metadata)), stream)))
    }
}
/// Type-erased acceptor.
pub struct DynAccept(Box);

/// Bound alias for metadata that can live inside [`DynMetadata`].
trait DynMetadataT: for<'a> Visit> + fmt::Debug + Send + Sync {}
impl DynMetadataT for T where for<'a> T: Visit> + fmt::Debug + Send + Sync {}

/// Boxed, type-erased connection metadata produced by [`DynAccept`].
#[derive(Debug)]
pub struct DynMetadata(Box);
impl<'a> Visit> for DynMetadata {
    fn visit(
        &self,
        visitor: &mut ExtensionVisitor<'a>,
    ) -> as Visitor>::Result {
        self.0.visit(visitor)
    }
}

impl Accept for DynAccept {
    type Metadata = DynMetadata;
    fn poll_accept(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll> {
        DynAcceptT::poll_accept(&mut *self.0, cx)
    }
    /// Already type-erased: returns `self` instead of boxing a second time.
    fn into_dyn(self) -> DynAccept
    where
        Self: Sized,
        for<'a> Self::Metadata: Visit> + Send + Sync + 'static,
    {
        self
    }
}
impl DynAccept {
    /// Boxes `accept` behind the object-safe `DynAcceptT` interface.
    pub fn new(accept: A) -> Self
    where
        A: Accept + Send + Sync + 'static,
        for<'a> ::Metadata: Visit> + Send + Sync + 'static,
    {
        Self(Box::new(accept))
    }
}

/// Holds the active acceptor in a [`Watch`] so that it can be replaced at runtime
/// (see `WebServer::acceptor_setter`).
#[pin_project::pin_project]
pub struct Acceptor {
    acceptor: Watch,
}
impl Acceptor {
    pub fn new(acceptor: A) -> Self {
        Self {
            acceptor: Watch::new(acceptor),
        }
    }

    fn poll_changed(&mut self, cx: &mut std::task::Context<'_>) -> Poll<()> {
        self.acceptor.poll_changed(cx)
    }

    fn poll_accept(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> Poll> {
        // Drain any pending "acceptor replaced" notifications. This presumably leaves
        // `cx`'s waker registered, so a future replacement wakes this task.
        while self.poll_changed(cx).is_ready() {}
        self.acceptor.peek_mut(|a| a.poll_accept(cx))
    }

    /// Waits for the next connection from the current acceptor.
    async fn accept(&mut self) -> Result<(A::Metadata, AcceptStream), Error> {
        std::future::poll_fn(|cx| self.poll_accept(cx)).await
    }
}

impl Acceptor> {
    /// Binds a TCP listener on every address. Fails if any single bind fails.
    pub async fn bind(listen: impl IntoIterator) -> Result {
        Ok(Self::new(
            futures::future::try_join_all(listen.into_iter().map(TcpListener::bind)).await?,
        ))
    }
}

impl Acceptor> {
    /// Like `bind`, but type-erases each listener into a [`DynAccept`].
    pub async fn bind_dyn(listen: impl IntoIterator) -> Result {
        Ok(Self::new(
            futures::future::try_join_all(
                listen
                    .into_iter()
                    .map(TcpListener::bind)
                    .map(|f| f.map_ok(DynAccept::new)),
            )
            .await?,
        ))
    }
}

impl Acceptor>
where
    K: Ord + Clone + fmt::Debug + Send + Sync + 'static,
{
    /// Binds one TCP listener per `(key, addr)` pair. Bind failures are reported
    /// with `ErrorKind::Network`.
    pub async fn bind_map(
        listen: impl IntoIterator,
    ) -> Result {
        Ok(Self::new(
            futures::future::try_join_all(listen.into_iter().map(|(key, addr)| async move {
                Ok::<_, Error>((
                    key,
                    TcpListener::bind(addr)
                        .await
                        .with_kind(ErrorKind::Network)?,
                ))
            }))
            .await?
            .into_iter()
            .collect(),
        ))
    }
}

impl Acceptor>
where
    K: Ord + Clone + fmt::Debug + Send + Sync + 'static,
{
    /// Like `bind_map`, but type-erases each listener into a [`DynAccept`].
    pub async fn bind_map_dyn(
        listen: impl IntoIterator,
    ) -> Result {
        Ok(Self::new(
            futures::future::try_join_all(listen.into_iter().map(|(key, addr)| async move {
                Ok::<_, Error>((
                    key,
                    TcpListener::bind(addr)
                        .await
                        .with_kind(ErrorKind::Network)?,
                ))
            }))
            .await?
            .into_iter()
            .map(|(key, listener)| (key, listener.into_dyn()))
            .collect(),
        ))
    }
}

/// Handle for replacing a running [`WebServer`]'s acceptor. It dereferences to the
/// underlying [`Watch`].
pub struct WebServerAcceptorSetter {
    acceptor: Watch,
}
impl WebServerAcceptorSetter>>
where
    A: Accept,
    B: Accept,
{
    /// If the current acceptor is `Some(Either::Left(a))`, replaces it with
    /// `Some(Either::Right(f(a)))`. Any other state is left untouched and `Ok(())` is
    /// returned.
    ///
    /// NOTE: `f` consumes the old acceptor. If it fails, the slot is left `None`, so
    /// the server accepts no new connections until another acceptor is sent.
    pub fn try_upgrade Result>(&self, f: F) -> Result<(), Error> {
        let mut res = Ok(());
        self.acceptor.send_modify(|a| {
            *a = match a.take() {
                Some(Either::Left(a)) => match f(a) {
                    Ok(b) => Some(Either::Right(b)),
                    Err(e) => {
                        res = Err(e);
                        None
                    }
                },
                x => x,
            }
        });
        res
    }
}
impl Deref for WebServerAcceptorSetter {
    type Target = Watch;
    fn deref(&self) -> &Self::Target {
        &self.acceptor
    }
}

/// HTTP/1 + HTTP/2 server that runs on a background task. The router and the
/// acceptor can be swapped while it runs.
pub struct WebServer {
    shutdown: oneshot::Sender<()>,
    router: Watch,
    acceptor: Watch,
    thread: NonDetachingJoinHandle<()>,
}
impl WebServer
where
    A: Accept + Send + Sync + 'static,
    for<'a> A::Metadata: Visit> + Send + Sync + 'static,
{
    /// Returns a handle that can replace the acceptor of this running server.
    pub fn acceptor_setter(&self) -> WebServerAcceptorSetter {
        WebServerAcceptorSetter {
            acceptor: self.acceptor.clone(),
        }
    }

    /// Spawns the server task. It accepts connections from `acceptor` and serves
    /// them with `router` (replaceable later via `serve_router`).
    pub fn new(mut acceptor: Acceptor, router: Router) -> Self {
        let acceptor_send = acceptor.acceptor.clone();
        let router = Watch::new(router);
        let service = router.clone_unseen();
        let (shutdown, shutdown_recv) = oneshot::channel();
        let thread = NonDetachingJoinHandle::from(tokio::spawn(async move {
            // hyper executor that pushes futures spawned by hyper onto the background
            // job queue. Once the queue cell has been cleared at shutdown, such futures
            // are dropped and a warning is logged.
            #[derive(Clone)]
            struct QueueRunner {
                queue: Arc>>,
            }

            impl hyper::rt::Executor for QueueRunner
            where
                Fut: Future + Send + 'static,
            {
                fn execute(&self, fut: Fut) {
                    self.queue.peek(|q| {
                        if let Some(q) = q {
                            q.add_job(fut);
                        } else {
                            tracing::warn!("job queued after shutdown");
                        }
                    })
                }
            }

            // Per-connection hyper service. It copies the connection metadata into each
            // request's extensions, then dispatches to whatever router is current, so
            // router swaps take effect even on already-open connections.
            struct SwappableRouter {
                router: Watch,
                metadata: M,
            }
            impl Visit> + Send + Sync + 'static> hyper::service::Service>
                for SwappableRouter
            {
                type Response = , >>::Response;
                type Error = , >>::Error;
                type Future = , >>::Future;

                fn call(&self, mut req: hyper::Request) -> Self::Future {
                    use tower_service::Service;

                    self.metadata
                        .visit(&mut ExtensionVisitor(req.extensions_mut()));

                    self.router.read().call(req)
                }
            }

            let queue_cell = Arc::new(SyncRwLock::new(None));
            let graceful = hyper_util::server::graceful::GracefulShutdown::new();
            let mut server = hyper_util::server::conn::auto::Builder::new(QueueRunner {
                queue: queue_cell.clone(),
            });
            server
                .http1()
                .timer(TokioTimer::new())
                .title_case_headers(true)
                .preserve_header_case(true)
                .http2()
                .timer(TokioTimer::new())
                .enable_connect_protocol()
                .keep_alive_interval(Duration::from_secs(25))
                .keep_alive_timeout(Duration::from_secs(300));
            let (queue, mut runner) = BackgroundJobQueue::new();
            queue_cell.replace(Some(queue.clone()));

            // Accept loop. Connections are handled in batches of up to 5 attempts. A
            // failed attempt sleeps 100 ms before the next one, so a persistent error
            // such as fd exhaustion does not spin. A batch ends at the first success or
            // after 5 failures, and the last error seen in the batch (if any) is logged.
            let handler = async {
                loop {
                    let mut err = None;
                    for _ in 0..5 {
                        if let Err(e) = async {
                            let (metadata, stream) = acceptor.accept().await?;
                            queue.add_job(
                                graceful.watch(
                                    server
                                        .serve_connection_with_upgrades(
                                            TokioIo::new(stream),
                                            SwappableRouter {
                                                router: service.clone(),
                                                metadata,
                                            },
                                        )
                                        .into_owned(),
                                ),
                            );

                            Ok::<_, Error>(())
                        }
                        .await
                        {
                            err = Some(e);
                            tokio::time::sleep(Duration::from_millis(100)).await;
                        } else {
                            break;
                        }
                    }
                    if let Some(e) = err {
                        tracing::error!("Error accepting HTTP connection: {e}");
                        tracing::debug!("{e:?}");
                    }
                }
            }
            .boxed();

            tokio::select! {
                _ = shutdown_recv => (),
                _ = handler => (),
                _ = &mut runner => (),
            }

            // Stop taking new jobs: drop our queue handle and clear the executor's copy.
            drop(queue);
            drop(queue_cell.replace(None));

            if !runner.is_empty() {
                // Ask every watched connection to shut down gracefully, and drain those
                // connections together with any remaining background jobs, all bounded
                // by the same 60 s timeout. Without an explicit signal here, `graceful`
                // would only be dropped when this task ends, i.e. after the drain. Open
                // keep-alive connections would then hold the drain open until the
                // timeout expired and be cut off instead of closed gracefully.
                tokio::time::timeout(
                    Duration::from_secs(60),
                    futures::future::join(graceful.shutdown(), runner),
                )
                .await
                .log_err();
            }
        }));

        Self {
            shutdown,
            router,
            thread,
            acceptor: acceptor_send,
        }
    }

    /// Signals the server task to stop and waits for it to finish draining.
    ///
    /// A failed send (the task has already exited) is ignored. Panics if awaiting the
    /// task returns an error, presumably only when the task itself panicked.
    pub async fn shutdown(self) {
        self.shutdown.send(()).unwrap_or_default();
        self.thread.await.unwrap()
    }

    /// Replaces the router. Subsequent requests, including those on already-open
    /// connections, are dispatched to the new router.
    pub fn serve_router(&mut self, router: Router) {
        self.router.send(router)
    }

    /// Serves the UI router built from `ctx`.
    pub fn serve_ui_for(&mut self, ctx: C) {
        self.serve_router(ui_router(ctx))
    }
}