fix ws timeouts

This commit is contained in:
Aiden McClelland
2025-12-18 14:54:19 -07:00
parent f8df692865
commit df3f79f282
19 changed files with 246 additions and 78 deletions

View File

@@ -28,7 +28,6 @@ use crate::rpc_continuations::{Guid, RpcContinuation, RpcContinuations};
use crate::setup::SetupProgress; use crate::setup::SetupProgress;
use crate::shutdown::Shutdown; use crate::shutdown::Shutdown;
use crate::util::future::NonDetachingJoinHandle; use crate::util::future::NonDetachingJoinHandle;
use crate::util::net::WebSocketExt;
lazy_static::lazy_static! { lazy_static::lazy_static! {
pub static ref CURRENT_SECRET: Jwk = Jwk::generate_ec_key(josekit::jwk::alg::ec::EcCurve::P256).unwrap_or_else(|e| { pub static ref CURRENT_SECRET: Jwk = Jwk::generate_ec_key(josekit::jwk::alg::ec::EcCurve::P256).unwrap_or_else(|e| {

View File

@@ -22,7 +22,6 @@ use ts_rs::TS;
use crate::context::{CliContext, RpcContext}; use crate::context::{CliContext, RpcContext};
use crate::prelude::*; use crate::prelude::*;
use crate::rpc_continuations::{Guid, RpcContinuation}; use crate::rpc_continuations::{Guid, RpcContinuation};
use crate::util::net::WebSocketExt;
use crate::util::serde::{HandlerExtSerde, apply_expr}; use crate::util::serde::{HandlerExtSerde, apply_expr};
lazy_static::lazy_static! { lazy_static::lazy_static! {

View File

@@ -36,7 +36,6 @@ use crate::ssh::SSH_DIR;
use crate::system::{get_mem_info, sync_kiosk}; use crate::system::{get_mem_info, sync_kiosk};
use crate::util::io::{IOHook, open_file}; use crate::util::io::{IOHook, open_file};
use crate::util::lshw::lshw; use crate::util::lshw::lshw;
use crate::util::net::WebSocketExt;
use crate::util::{Invoke, cpupower}; use crate::util::{Invoke, cpupower};
use crate::{Error, MAIN_DATA, PACKAGE_DATA, ResultExt}; use crate::{Error, MAIN_DATA, PACKAGE_DATA, ResultExt};

View File

@@ -31,7 +31,6 @@ use crate::s9pk::manifest::PackageId;
use crate::s9pk::v2::SIG_CONTEXT; use crate::s9pk::v2::SIG_CONTEXT;
use crate::upload::upload; use crate::upload::upload;
use crate::util::io::open_file; use crate::util::io::open_file;
use crate::util::net::WebSocketExt;
use crate::util::tui::choose; use crate::util::tui::choose;
use crate::util::{FromStrParser, Never, VersionString}; use crate::util::{FromStrParser, Never, VersionString};

View File

@@ -5,13 +5,14 @@ use std::process::Stdio;
use std::str::FromStr; use std::str::FromStr;
use std::time::{Duration, UNIX_EPOCH}; use std::time::{Duration, UNIX_EPOCH};
use axum::extract::ws::{self, WebSocket}; use axum::extract::ws;
use crate::util::net::WebSocket;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use clap::builder::ValueParserFactory; use clap::builder::ValueParserFactory;
use clap::{Args, FromArgMatches, Parser}; use clap::{Args, FromArgMatches, Parser};
use color_eyre::eyre::eyre; use color_eyre::eyre::eyre;
use futures::stream::BoxStream; use futures::stream::BoxStream;
use futures::{Future, FutureExt, Stream, StreamExt, TryStreamExt}; use futures::{Future, Stream, StreamExt, TryStreamExt};
use itertools::Itertools; use itertools::Itertools;
use rpc_toolkit::yajrc::RpcError; use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{ use rpc_toolkit::{
@@ -30,7 +31,6 @@ use crate::context::{CliContext, RpcContext};
use crate::error::ResultExt; use crate::error::ResultExt;
use crate::prelude::*; use crate::prelude::*;
use crate::rpc_continuations::{Guid, RpcContinuation, RpcContinuations}; use crate::rpc_continuations::{Guid, RpcContinuation, RpcContinuations};
use crate::util::net::WebSocketExt;
use crate::util::serde::Reversible; use crate::util::serde::Reversible;
use crate::util::{FromStrParser, Invoke}; use crate::util::{FromStrParser, Invoke};
@@ -100,8 +100,8 @@ async fn ws_handler(
return stream.normal_close("complete").await; return stream.normal_close("complete").await;
} }
}, },
msg = stream.try_next() => { msg = stream.recv() => {
if msg.with_kind(crate::ErrorKind::Network)?.is_none() { if msg.transpose().with_kind(crate::ErrorKind::Network)?.is_none() {
return Ok(()) return Ok(())
} }
} }
@@ -698,16 +698,11 @@ pub async fn follow_logs<Context: AsRef<RpcContinuations>>(
.add( .add(
guid.clone(), guid.clone(),
RpcContinuation::ws( RpcContinuation::ws(
Box::new(move |socket| { move |socket| async move {
ws_handler(first_entry, stream, socket) if let Err(e) = ws_handler(first_entry, stream, socket).await {
.map(|x| match x { tracing::error!("Error in log stream: {}", e);
Ok(_) => (), }
Err(e) => { },
tracing::error!("Error in log stream: {}", e);
}
})
.boxed()
}),
Duration::from_secs(30), Duration::from_secs(30),
), ),
) )

View File

@@ -427,7 +427,7 @@ pub async fn connect(ctx: &RpcContext, container: &LxcContainer) -> Result<Guid,
|mut ws| async move { |mut ws| async move {
if let Err(e) = async { if let Err(e) = async {
loop { loop {
match ws.next().await { match ws.recv().await {
None => break, None => break,
Some(Ok(Message::Text(txt))) => { Some(Ok(Message::Text(txt))) => {
let mut id = None; let mut id = None;

View File

@@ -1,5 +1,6 @@
use std::net::IpAddr; use std::net::IpAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use futures::FutureExt; use futures::FutureExt;
use http::HeaderValue; use http::HeaderValue;
@@ -90,11 +91,15 @@ where
{ {
let (client, to) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()) let (client, to) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new()) .timer(TokioTimer::new())
.keep_alive_interval(Duration::from_secs(25))
.keep_alive_timeout(Duration::from_secs(300))
.handshake(TokioIo::new(to)) .handshake(TokioIo::new(to))
.await?; .await?;
let from = hyper::server::conn::http2::Builder::new(TokioExecutor::new()) let from = hyper::server::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new()) .timer(TokioTimer::new())
.enable_connect_protocol() .enable_connect_protocol()
.keep_alive_interval(Duration::from_secs(25)) // Add this
.keep_alive_timeout(Duration::from_secs(300))
.serve_connection( .serve_connection(
TokioIo::new(from), TokioIo::new(from),
service_fn(|mut req| { service_fn(|mut req| {

View File

@@ -100,6 +100,15 @@ impl SocksController {
TcpStream::connect(addr).await TcpStream::connect(addr).await
} }
} { } {
if let Err(e) =
socket2::SockRef::from(&target)
.set_keepalive(true)
{
tracing::error!(
"Failed to set tcp keepalive: {e}"
);
tracing::debug!("{e:?}");
}
let mut sock = reply let mut sock = reply
.reply( .reply(
Reply::Succeeded, Reply::Succeeded,

View File

@@ -680,11 +680,14 @@ impl TorController {
}) })
}) })
}) { }) {
Ok(Box::new( let tcp_stream = TcpStream::connect(target)
TcpStream::connect(target) .await
.await .with_kind(ErrorKind::Network)?;
.with_kind(ErrorKind::Network)?, if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) {
)) tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
Ok(Box::new(tcp_stream))
} else { } else {
let mut client = self.0.client.clone(); let mut client = self.0.client.clone();
client client
@@ -808,6 +811,10 @@ impl OnionService {
TcpStream::connect(target) TcpStream::connect(target)
.await .await
.with_kind(ErrorKind::Network)?; .with_kind(ErrorKind::Network)?;
if let Err(e) = socket2::SockRef::from(&outgoing).set_keepalive(true) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
let mut incoming = req let mut incoming = req
.accept(Connected::new_empty()) .accept(Connected::new_empty())
.await .await

View File

@@ -499,15 +499,22 @@ impl TorController {
}) })
}) { }) {
tracing::debug!("Resolving {addr} internally to {target}"); tracing::debug!("Resolving {addr} internally to {target}");
Ok(Box::new( let tcp_stream = TcpStream::connect(target)
TcpStream::connect(target) .await
.await .with_kind(ErrorKind::Network)?;
.with_kind(ErrorKind::Network)?, if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) {
)) tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
Ok(Box::new(tcp_stream))
} else { } else {
let mut stream = TcpStream::connect(TOR_SOCKS) let mut stream = TcpStream::connect(TOR_SOCKS)
.await .await
.with_kind(ErrorKind::Tor)?; .with_kind(ErrorKind::Tor)?;
if let Err(e) = socket2::SockRef::from(&stream).set_keepalive(true) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
socks5_impl::client::connect(&mut stream, (addr.to_string(), port), None) socks5_impl::client::connect(&mut stream, (addr.to_string(), port), None)
.await .await
.with_kind(ErrorKind::Tor)?; .with_kind(ErrorKind::Tor)?;
@@ -665,6 +672,11 @@ async fn torctl(
})?; })?;
tracing::info!("Tor is started"); tracing::info!("Tor is started");
if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
let mut conn = torut::control::UnauthenticatedConn::new(tcp_stream); let mut conn = torut::control::UnauthenticatedConn::new(tcp_stream);
let auth = conn let auth = conn
.load_protocol_info() .load_protocol_info()

View File

@@ -4,6 +4,7 @@ use std::fmt;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
use std::task::{Poll, ready}; use std::task::{Poll, ready};
use std::time::Duration;
use async_acme::acme::ACME_TLS_ALPN_NAME; use async_acme::acme::ACME_TLS_ALPN_NAME;
use color_eyre::eyre::eyre; use color_eyre::eyre::eyre;
@@ -359,6 +360,10 @@ where
.await .await
.with_ctx(|_| (ErrorKind::Network, self.addr)) .with_ctx(|_| (ErrorKind::Network, self.addr))
.log_err()?; .log_err()?;
if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
match &self.connect_ssl { match &self.connect_ssl {
Ok(client_cfg) => { Ok(client_cfg) => {
let mut client_cfg = (&**client_cfg).clone(); let mut client_cfg = (&**client_cfg).clone();

View File

@@ -106,12 +106,7 @@ impl Accept for TcpListener {
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> { ) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
if let Poll::Ready((stream, peer_addr)) = TcpListener::poll_accept(self, cx)? { if let Poll::Ready((stream, peer_addr)) = TcpListener::poll_accept(self, cx)? {
if let Err(e) = socket2::SockRef::from(&stream).set_tcp_keepalive( if let Err(e) = socket2::SockRef::from(&stream).set_keepalive(true) {
&socket2::TcpKeepalive::new()
.with_time(Duration::from_secs(900))
.with_interval(Duration::from_secs(60))
.with_retries(5),
) {
tracing::error!("Failed to set tcp keepalive: {e}"); tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}"); tracing::debug!("{e:?}");
} }

View File

@@ -6,7 +6,7 @@ use std::task::{Context, Poll};
use std::time::Duration; use std::time::Duration;
use axum::extract::Request; use axum::extract::Request;
use axum::extract::ws::WebSocket; use axum::extract::ws::WebSocket as AxumWebSocket;
use axum::response::Response; use axum::response::Response;
use clap::builder::ValueParserFactory; use clap::builder::ValueParserFactory;
use futures::future::BoxFuture; use futures::future::BoxFuture;
@@ -18,6 +18,7 @@ use ts_rs::TS;
#[allow(unused_imports)] #[allow(unused_imports)]
use crate::prelude::*; use crate::prelude::*;
use crate::util::future::TimedResource; use crate::util::future::TimedResource;
use crate::util::net::WebSocket;
use crate::util::{FromStrParser, new_guid}; use crate::util::{FromStrParser, new_guid};
#[derive( #[derive(
@@ -109,7 +110,7 @@ impl Future for WebSocketFuture {
} }
} }
} }
pub type WebSocketHandler = Box<dyn FnOnce(WebSocket) -> WebSocketFuture + Send>; pub type WebSocketHandler = Box<dyn FnOnce(AxumWebSocket) -> WebSocketFuture + Send>;
pub enum RpcContinuation { pub enum RpcContinuation {
Rest(TimedResource<RestHandler>), Rest(TimedResource<RestHandler>),
@@ -137,7 +138,7 @@ impl RpcContinuation {
RpcContinuation::WebSocket(TimedResource::new( RpcContinuation::WebSocket(TimedResource::new(
Box::new(|ws| WebSocketFuture { Box::new(|ws| WebSocketFuture {
kill: None, kill: None,
fut: handler(ws).boxed(), fut: handler(ws.into()).boxed(),
}), }),
timeout, timeout,
)) ))
@@ -169,7 +170,7 @@ impl RpcContinuation {
RpcContinuation::WebSocket(TimedResource::new( RpcContinuation::WebSocket(TimedResource::new(
Box::new(|ws| WebSocketFuture { Box::new(|ws| WebSocketFuture {
kill, kill,
fut: handler(ws).boxed(), fut: handler(ws.into()).boxed(),
}), }),
timeout, timeout,
)) ))

View File

@@ -83,9 +83,10 @@ impl procfs::FromBufRead for NSPid {
fn from_buf_read<R: std::io::BufRead>(r: R) -> procfs::ProcResult<Self> { fn from_buf_read<R: std::io::BufRead>(r: R) -> procfs::ProcResult<Self> {
for line in r.lines() { for line in r.lines() {
let line = line?; let line = line?;
if let Some(row) = line.trim().strip_prefix("NSpid") { if let Some(row) = line.trim().strip_prefix("NSpid:") {
return Ok(Self( return Ok(Self(
row.split_ascii_whitespace() row.trim()
.split_ascii_whitespace()
.map(|pid| pid.parse::<i32>()) .map(|pid| pid.parse::<i32>())
.collect::<Result<Vec<_>, _>>()?, .collect::<Result<Vec<_>, _>>()?,
)); ));
@@ -205,9 +206,9 @@ impl ExecParams {
split.next() split.next()
}) })
}; };
std::os::unix::fs::chown("/proc/self/fd/0", Some(uid), Some(gid)).log_err(); std::os::unix::fs::chown("/proc/self/fd/0", Some(uid), Some(gid)).ok();
std::os::unix::fs::chown("/proc/self/fd/1", Some(uid), Some(gid)).log_err(); std::os::unix::fs::chown("/proc/self/fd/1", Some(uid), Some(gid)).ok();
std::os::unix::fs::chown("/proc/self/fd/2", Some(uid), Some(gid)).log_err(); std::os::unix::fs::chown("/proc/self/fd/2", Some(uid), Some(gid)).ok();
cmd.uid(uid); cmd.uid(uid);
cmd.gid(gid); cmd.gid(gid);
} else { } else {
@@ -290,8 +291,6 @@ pub fn launch(
None None
}; };
let pty_size = pty_size.or_else(|| TermSize::get_current());
let (stdin_send, stdin_recv) = oneshot::channel::<Box<dyn Write + Send>>(); let (stdin_send, stdin_recv) = oneshot::channel::<Box<dyn Write + Send>>();
std::thread::spawn(move || { std::thread::spawn(move || {
if let Ok(mut cstdin) = stdin_recv.blocking_recv() { if let Ok(mut cstdin) = stdin_recv.blocking_recv() {
@@ -370,7 +369,7 @@ pub fn launch(
.map_err(color_eyre::eyre::Report::msg) .map_err(color_eyre::eyre::Report::msg)
.with_ctx(|_| (ErrorKind::Filesystem, "spawning child process"))?; .with_ctx(|_| (ErrorKind::Filesystem, "spawning child process"))?;
send_pid.send(child.id() as i32).unwrap_or_default(); send_pid.send(child.id() as i32).unwrap_or_default();
if let Some(pty_size) = pty_size { if let Some(pty_size) = pty_size.or_else(|| TermSize::get_current()) {
let size = if let Some((x, y)) = pty_size.pixels { let size = if let Some((x, y)) = pty_size.pixels {
::pty_process::Size::new_with_pixel(pty_size.rows, pty_size.cols, x, y) ::pty_process::Size::new_with_pixel(pty_size.rows, pty_size.cols, x, y)
} else { } else {
@@ -541,8 +540,6 @@ pub fn exec(
None None
}; };
let pty_size = pty_size.or_else(|| TermSize::get_current());
let (stdin_send, stdin_recv) = oneshot::channel::<Box<dyn Write + Send>>(); let (stdin_send, stdin_recv) = oneshot::channel::<Box<dyn Write + Send>>();
std::thread::spawn(move || { std::thread::spawn(move || {
if let Ok(mut cstdin) = stdin_recv.blocking_recv() { if let Ok(mut cstdin) = stdin_recv.blocking_recv() {
@@ -630,7 +627,7 @@ pub fn exec(
.map_err(color_eyre::eyre::Report::msg) .map_err(color_eyre::eyre::Report::msg)
.with_ctx(|_| (ErrorKind::Filesystem, "spawning child process"))?; .with_ctx(|_| (ErrorKind::Filesystem, "spawning child process"))?;
send_pid.send(child.id() as i32).unwrap_or_default(); send_pid.send(child.id() as i32).unwrap_or_default();
if let Some(pty_size) = pty_size { if let Some(pty_size) = pty_size.or_else(|| TermSize::get_current()) {
let size = if let Some((x, y)) = pty_size.pixels { let size = if let Some((x, y)) = pty_size.pixels {
::pty_process::Size::new_with_pixel(pty_size.rows, pty_size.cols, x, y) ::pty_process::Size::new_with_pixel(pty_size.rows, pty_size.cols, x, y)
} else { } else {

View File

@@ -8,7 +8,8 @@ use std::process::Stdio;
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
use std::time::Duration; use std::time::Duration;
use axum::extract::ws::{Utf8Bytes, WebSocket}; use axum::extract::ws::Utf8Bytes;
use crate::util::net::WebSocket;
use clap::Parser; use clap::Parser;
use futures::future::BoxFuture; use futures::future::BoxFuture;
use futures::stream::FusedStream; use futures::stream::FusedStream;
@@ -47,7 +48,6 @@ use crate::util::Never;
use crate::util::actor::concurrent::ConcurrentActor; use crate::util::actor::concurrent::ConcurrentActor;
use crate::util::future::NonDetachingJoinHandle; use crate::util::future::NonDetachingJoinHandle;
use crate::util::io::{AsyncReadStream, AtomicFile, TermSize, delete_file}; use crate::util::io::{AsyncReadStream, AtomicFile, TermSize, delete_file};
use crate::util::net::WebSocketExt;
use crate::util::serde::Pem; use crate::util::serde::Pem;
use crate::util::sync::SyncMutex; use crate::util::sync::SyncMutex;
use crate::volume::data_dir; use crate::volume::data_dir;

View File

@@ -5,7 +5,7 @@ use std::time::Duration;
use chrono::Utc; use chrono::Utc;
use clap::Parser; use clap::Parser;
use color_eyre::eyre::eyre; use color_eyre::eyre::eyre;
use futures::{FutureExt, TryStreamExt}; use futures::FutureExt;
use imbl::vector; use imbl::vector;
use imbl_value::InternedString; use imbl_value::InternedString;
use rpc_toolkit::{Context, Empty, HandlerExt, ParentHandler, from_fn_async}; use rpc_toolkit::{Context, Empty, HandlerExt, ParentHandler, from_fn_async};
@@ -24,7 +24,6 @@ use crate::shutdown::Shutdown;
use crate::util::Invoke; use crate::util::Invoke;
use crate::util::cpupower::{Governor, get_available_governors, set_governor}; use crate::util::cpupower::{Governor, get_available_governors, set_governor};
use crate::util::io::open_file; use crate::util::io::open_file;
use crate::util::net::WebSocketExt;
use crate::util::serde::{HandlerExtSerde, WithIoFormat, display_serializable}; use crate::util::serde::{HandlerExtSerde, WithIoFormat, display_serializable};
use crate::util::sync::Watch; use crate::util::sync::Watch;
use crate::{MAIN_DATA, PACKAGE_DATA}; use crate::{MAIN_DATA, PACKAGE_DATA};
@@ -527,8 +526,8 @@ pub async fn metrics_follow(
.into(), .into(),
)).await.with_kind(ErrorKind::Network)?; )).await.with_kind(ErrorKind::Network)?;
} }
msg = ws.try_next() => { msg = ws.recv() => {
if msg.with_kind(crate::ErrorKind::Network)?.is_none() { if msg.transpose().with_kind(crate::ErrorKind::Network)?.is_none() {
break; break;
} }
} }

View File

@@ -27,7 +27,6 @@ use crate::tunnel::auth::SignerInfo;
use crate::tunnel::context::TunnelContext; use crate::tunnel::context::TunnelContext;
use crate::tunnel::web::WebserverInfo; use crate::tunnel::web::WebserverInfo;
use crate::tunnel::wg::WgServer; use crate::tunnel::wg::WgServer;
use crate::util::net::WebSocketExt;
use crate::util::serde::{HandlerExtSerde, apply_expr}; use crate::util::serde::{HandlerExtSerde, apply_expr};
#[derive(Default, Deserialize, Serialize, HasModel, TS)] #[derive(Default, Deserialize, Serialize, HasModel, TS)]

View File

@@ -37,7 +37,6 @@ use crate::sound::{
use crate::util::Invoke; use crate::util::Invoke;
use crate::util::future::NonDetachingJoinHandle; use crate::util::future::NonDetachingJoinHandle;
use crate::util::io::AtomicFile; use crate::util::io::AtomicFile;
use crate::util::net::WebSocketExt;
#[derive(Deserialize, Serialize, Parser, TS)] #[derive(Deserialize, Serialize, Parser, TS)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]

View File

@@ -1,40 +1,154 @@
use core::fmt; use core::fmt;
use std::pin::Pin;
use std::sync::Mutex; use std::sync::Mutex;
use std::task::{Context, Poll, ready};
use std::time::Duration;
use axum::extract::ws::{self, CloseFrame, Utf8Bytes}; use axum::extract::ws::{CloseFrame, Message, Utf8Bytes, WebSocket as AxumWebSocket};
use futures::{Future, Stream, StreamExt}; use futures::{Sink, SinkExt, Stream, StreamExt};
use tokio::time::{Instant, Sleep};
use crate::prelude::*; use crate::prelude::*;
pub trait WebSocketExt { const PING_INTERVAL: Duration = Duration::from_secs(30);
fn normal_close( const PING_TIMEOUT: Duration = Duration::from_secs(300);
self,
msg: impl Into<Utf8Bytes> + Send, /// A wrapper around axum's WebSocket that automatically sends ping frames
) -> impl Future<Output = Result<(), Error>> + Send; /// to keep the connection alive during HTTP/2.
fn close_result( ///
self, /// HTTP/2 streams can timeout if idle, even when the underlying connection
result: Result<impl Into<Utf8Bytes> + Send, impl fmt::Display + Send>, /// has keep-alive enabled. This wrapper sends a ping frame if no data has
) -> impl Future<Output = Result<(), Error>> + Send; /// been sent within the ping interval while waiting to receive a message.
pub struct WebSocket {
inner: AxumWebSocket,
ping_state: Option<(bool, u64)>,
next_ping: Pin<Box<Sleep>>,
fused: bool,
} }
impl WebSocketExt for ws::WebSocket { impl WebSocket {
async fn normal_close(self, msg: impl Into<Utf8Bytes> + Send) -> Result<(), Error> { pub fn new(ws: AxumWebSocket) -> Self {
self.close_result(Ok::<_, Error>(msg)).await Self {
inner: ws,
ping_state: None,
next_ping: Box::pin(tokio::time::sleep(PING_INTERVAL)),
fused: false,
}
} }
async fn close_result(
pub fn into_inner(self) -> AxumWebSocket {
self.inner
}
pub async fn send(&mut self, msg: Message) -> Result<(), axum::Error> {
if self.ping_state.is_none() {
self.next_ping
.as_mut()
.reset(Instant::now() + PING_INTERVAL);
}
self.inner.send(msg).await
}
pub fn poll_recv(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Message, axum::Error>>> {
if self.fused {
return Poll::Ready(None);
}
let mut inner = Pin::new(&mut self.inner);
loop {
if let Poll::Ready(msg) = inner.as_mut().poll_next(cx) {
if self.ping_state.is_none() {
self.next_ping
.as_mut()
.reset(Instant::now() + PING_INTERVAL);
}
if let Some(Ok(Message::Pong(x))) = &msg {
if let Some((true, id)) = self.ping_state {
if &u64::to_be_bytes(id)[..] == &**x {
self.ping_state.take();
self.next_ping
.as_mut()
.reset(Instant::now() + PING_INTERVAL);
continue;
}
}
}
break Poll::Ready(msg);
}
if let Some((sent, id)) = &mut self.ping_state {
if !*sent {
ready!(inner.as_mut().poll_ready(cx))?;
inner
.as_mut()
.start_send(Message::Ping((u64::to_be_bytes(*id).to_vec()).into()))?;
self.next_ping.as_mut().reset(Instant::now() + PING_TIMEOUT);
*sent = true;
}
ready!(inner.as_mut().poll_flush(cx))?;
}
ready!(self.next_ping.as_mut().poll(cx));
if self.ping_state.is_some() {
self.fused = true;
break Poll::Ready(Some(Err(axum::Error::new(eyre!(
"Timeout: WebSocket did not respond to ping within {PING_TIMEOUT:?}"
)))));
}
self.ping_state = Some((false, rand::random()));
}
}
/// Receive a message from the websocket, automatically sending ping frames
/// if the connection is idle for too long.
///
/// Ping and Pong frames are handled internally and not returned to the caller.
pub async fn recv(&mut self) -> Option<Result<Message, axum::Error>> {
futures::future::poll_fn(|cx| self.poll_recv(cx)).await
}
/// Close the websocket connection normally.
pub async fn normal_close(mut self, msg: impl Into<Utf8Bytes>) -> Result<(), Error> {
self.inner
.send(Message::Close(Some(CloseFrame {
code: 1000,
reason: msg.into(),
})))
.await
.with_kind(ErrorKind::Network)?;
while !matches!(
self.inner
.recv()
.await
.transpose()
.with_kind(ErrorKind::Network)?,
Some(Message::Close(_)) | None
) {}
Ok(())
}
/// Close the websocket connection with a result.
pub async fn close_result(
mut self, mut self,
result: Result<impl Into<Utf8Bytes> + Send, impl fmt::Display + Send>, result: Result<impl Into<Utf8Bytes> + Send, impl fmt::Display + Send>,
) -> Result<(), Error> { ) -> Result<(), Error> {
match result { match result {
Ok(msg) => self Ok(msg) => self
.send(ws::Message::Close(Some(CloseFrame { .inner
.send(Message::Close(Some(CloseFrame {
code: 1000, code: 1000,
reason: msg.into(), reason: msg.into(),
}))) })))
.await .await
.with_kind(ErrorKind::Network)?, .with_kind(ErrorKind::Network)?,
Err(e) => self Err(e) => self
.send(ws::Message::Close(Some(CloseFrame { .inner
.send(Message::Close(Some(CloseFrame {
code: 1011, code: 1011,
reason: e.to_string().into(), reason: e.to_string().into(),
}))) })))
@@ -42,16 +156,51 @@ impl WebSocketExt for ws::WebSocket {
.with_kind(ErrorKind::Network)?, .with_kind(ErrorKind::Network)?,
} }
while !matches!( while !matches!(
self.recv() self.inner
.recv()
.await .await
.transpose() .transpose()
.with_kind(ErrorKind::Network)?, .with_kind(ErrorKind::Network)?,
Some(ws::Message::Close(_)) | None Some(Message::Close(_)) | None
) {} ) {}
Ok(()) Ok(())
} }
} }
impl From<AxumWebSocket> for WebSocket {
fn from(ws: AxumWebSocket) -> Self {
Self::new(ws)
}
}
impl Stream for WebSocket {
type Item = Result<Message, axum::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.get_mut().poll_recv(cx)
}
}
impl Sink<Message> for WebSocket {
type Error = axum::Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.get_mut().inner.poll_ready_unpin(cx)
}
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
self.get_mut().inner.start_send_unpin(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.get_mut().inner.poll_flush_unpin(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.get_mut().inner.poll_close_unpin(cx)
}
}
pub struct SyncBody(Mutex<axum::body::BodyDataStream>); pub struct SyncBody(Mutex<axum::body::BodyDataStream>);
impl From<axum::body::Body> for SyncBody { impl From<axum::body::Body> for SyncBody {
fn from(value: axum::body::Body) -> Self { fn from(value: axum::body::Body) -> Self {