fix ws timeouts

This commit is contained in:
Aiden McClelland
2025-12-18 14:54:19 -07:00
parent f8df692865
commit df3f79f282
19 changed files with 246 additions and 78 deletions

View File

@@ -28,7 +28,6 @@ use crate::rpc_continuations::{Guid, RpcContinuation, RpcContinuations};
use crate::setup::SetupProgress;
use crate::shutdown::Shutdown;
use crate::util::future::NonDetachingJoinHandle;
use crate::util::net::WebSocketExt;
lazy_static::lazy_static! {
pub static ref CURRENT_SECRET: Jwk = Jwk::generate_ec_key(josekit::jwk::alg::ec::EcCurve::P256).unwrap_or_else(|e| {

View File

@@ -22,7 +22,6 @@ use ts_rs::TS;
use crate::context::{CliContext, RpcContext};
use crate::prelude::*;
use crate::rpc_continuations::{Guid, RpcContinuation};
use crate::util::net::WebSocketExt;
use crate::util::serde::{HandlerExtSerde, apply_expr};
lazy_static::lazy_static! {

View File

@@ -36,7 +36,6 @@ use crate::ssh::SSH_DIR;
use crate::system::{get_mem_info, sync_kiosk};
use crate::util::io::{IOHook, open_file};
use crate::util::lshw::lshw;
use crate::util::net::WebSocketExt;
use crate::util::{Invoke, cpupower};
use crate::{Error, MAIN_DATA, PACKAGE_DATA, ResultExt};

View File

@@ -31,7 +31,6 @@ use crate::s9pk::manifest::PackageId;
use crate::s9pk::v2::SIG_CONTEXT;
use crate::upload::upload;
use crate::util::io::open_file;
use crate::util::net::WebSocketExt;
use crate::util::tui::choose;
use crate::util::{FromStrParser, Never, VersionString};

View File

@@ -5,13 +5,14 @@ use std::process::Stdio;
use std::str::FromStr;
use std::time::{Duration, UNIX_EPOCH};
use axum::extract::ws::{self, WebSocket};
use axum::extract::ws;
use crate::util::net::WebSocket;
use chrono::{DateTime, Utc};
use clap::builder::ValueParserFactory;
use clap::{Args, FromArgMatches, Parser};
use color_eyre::eyre::eyre;
use futures::stream::BoxStream;
use futures::{Future, FutureExt, Stream, StreamExt, TryStreamExt};
use futures::{Future, Stream, StreamExt, TryStreamExt};
use itertools::Itertools;
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{
@@ -30,7 +31,6 @@ use crate::context::{CliContext, RpcContext};
use crate::error::ResultExt;
use crate::prelude::*;
use crate::rpc_continuations::{Guid, RpcContinuation, RpcContinuations};
use crate::util::net::WebSocketExt;
use crate::util::serde::Reversible;
use crate::util::{FromStrParser, Invoke};
@@ -100,8 +100,8 @@ async fn ws_handler(
return stream.normal_close("complete").await;
}
},
msg = stream.try_next() => {
if msg.with_kind(crate::ErrorKind::Network)?.is_none() {
msg = stream.recv() => {
if msg.transpose().with_kind(crate::ErrorKind::Network)?.is_none() {
return Ok(())
}
}
@@ -698,16 +698,11 @@ pub async fn follow_logs<Context: AsRef<RpcContinuations>>(
.add(
guid.clone(),
RpcContinuation::ws(
Box::new(move |socket| {
ws_handler(first_entry, stream, socket)
.map(|x| match x {
Ok(_) => (),
Err(e) => {
tracing::error!("Error in log stream: {}", e);
}
})
.boxed()
}),
move |socket| async move {
if let Err(e) = ws_handler(first_entry, stream, socket).await {
tracing::error!("Error in log stream: {}", e);
}
},
Duration::from_secs(30),
),
)

View File

@@ -427,7 +427,7 @@ pub async fn connect(ctx: &RpcContext, container: &LxcContainer) -> Result<Guid,
|mut ws| async move {
if let Err(e) = async {
loop {
match ws.next().await {
match ws.recv().await {
None => break,
Some(Ok(Message::Text(txt))) => {
let mut id = None;

View File

@@ -1,5 +1,6 @@
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use futures::FutureExt;
use http::HeaderValue;
@@ -90,11 +91,15 @@ where
{
let (client, to) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.keep_alive_interval(Duration::from_secs(25))
.keep_alive_timeout(Duration::from_secs(300))
.handshake(TokioIo::new(to))
.await?;
let from = hyper::server::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.enable_connect_protocol()
.keep_alive_interval(Duration::from_secs(25)) // send HTTP/2 keep-alive pings every 25s
.keep_alive_timeout(Duration::from_secs(300))
.serve_connection(
TokioIo::new(from),
service_fn(|mut req| {

View File

@@ -100,6 +100,15 @@ impl SocksController {
TcpStream::connect(addr).await
}
} {
if let Err(e) =
socket2::SockRef::from(&target)
.set_keepalive(true)
{
tracing::error!(
"Failed to set tcp keepalive: {e}"
);
tracing::debug!("{e:?}");
}
let mut sock = reply
.reply(
Reply::Succeeded,

View File

@@ -680,11 +680,14 @@ impl TorController {
})
})
}) {
Ok(Box::new(
TcpStream::connect(target)
.await
.with_kind(ErrorKind::Network)?,
))
let tcp_stream = TcpStream::connect(target)
.await
.with_kind(ErrorKind::Network)?;
if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
Ok(Box::new(tcp_stream))
} else {
let mut client = self.0.client.clone();
client
@@ -808,6 +811,10 @@ impl OnionService {
TcpStream::connect(target)
.await
.with_kind(ErrorKind::Network)?;
if let Err(e) = socket2::SockRef::from(&outgoing).set_keepalive(true) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
let mut incoming = req
.accept(Connected::new_empty())
.await

View File

@@ -499,15 +499,22 @@ impl TorController {
})
}) {
tracing::debug!("Resolving {addr} internally to {target}");
Ok(Box::new(
TcpStream::connect(target)
.await
.with_kind(ErrorKind::Network)?,
))
let tcp_stream = TcpStream::connect(target)
.await
.with_kind(ErrorKind::Network)?;
if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
Ok(Box::new(tcp_stream))
} else {
let mut stream = TcpStream::connect(TOR_SOCKS)
.await
.with_kind(ErrorKind::Tor)?;
if let Err(e) = socket2::SockRef::from(&stream).set_keepalive(true) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
socks5_impl::client::connect(&mut stream, (addr.to_string(), port), None)
.await
.with_kind(ErrorKind::Tor)?;
@@ -665,6 +672,11 @@ async fn torctl(
})?;
tracing::info!("Tor is started");
if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
let mut conn = torut::control::UnauthenticatedConn::new(tcp_stream);
let auth = conn
.load_protocol_info()

View File

@@ -4,6 +4,7 @@ use std::fmt;
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, Weak};
use std::task::{Poll, ready};
use std::time::Duration;
use async_acme::acme::ACME_TLS_ALPN_NAME;
use color_eyre::eyre::eyre;
@@ -359,6 +360,10 @@ where
.await
.with_ctx(|_| (ErrorKind::Network, self.addr))
.log_err()?;
if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}
match &self.connect_ssl {
Ok(client_cfg) => {
let mut client_cfg = (&**client_cfg).clone();

View File

@@ -106,12 +106,7 @@ impl Accept for TcpListener {
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
if let Poll::Ready((stream, peer_addr)) = TcpListener::poll_accept(self, cx)? {
if let Err(e) = socket2::SockRef::from(&stream).set_tcp_keepalive(
&socket2::TcpKeepalive::new()
.with_time(Duration::from_secs(900))
.with_interval(Duration::from_secs(60))
.with_retries(5),
) {
if let Err(e) = socket2::SockRef::from(&stream).set_keepalive(true) {
tracing::error!("Failed to set tcp keepalive: {e}");
tracing::debug!("{e:?}");
}

View File

@@ -6,7 +6,7 @@ use std::task::{Context, Poll};
use std::time::Duration;
use axum::extract::Request;
use axum::extract::ws::WebSocket;
use axum::extract::ws::WebSocket as AxumWebSocket;
use axum::response::Response;
use clap::builder::ValueParserFactory;
use futures::future::BoxFuture;
@@ -18,6 +18,7 @@ use ts_rs::TS;
#[allow(unused_imports)]
use crate::prelude::*;
use crate::util::future::TimedResource;
use crate::util::net::WebSocket;
use crate::util::{FromStrParser, new_guid};
#[derive(
@@ -109,7 +110,7 @@ impl Future for WebSocketFuture {
}
}
}
pub type WebSocketHandler = Box<dyn FnOnce(WebSocket) -> WebSocketFuture + Send>;
pub type WebSocketHandler = Box<dyn FnOnce(AxumWebSocket) -> WebSocketFuture + Send>;
pub enum RpcContinuation {
Rest(TimedResource<RestHandler>),
@@ -137,7 +138,7 @@ impl RpcContinuation {
RpcContinuation::WebSocket(TimedResource::new(
Box::new(|ws| WebSocketFuture {
kill: None,
fut: handler(ws).boxed(),
fut: handler(ws.into()).boxed(),
}),
timeout,
))
@@ -169,7 +170,7 @@ impl RpcContinuation {
RpcContinuation::WebSocket(TimedResource::new(
Box::new(|ws| WebSocketFuture {
kill,
fut: handler(ws).boxed(),
fut: handler(ws.into()).boxed(),
}),
timeout,
))

View File

@@ -83,9 +83,10 @@ impl procfs::FromBufRead for NSPid {
fn from_buf_read<R: std::io::BufRead>(r: R) -> procfs::ProcResult<Self> {
for line in r.lines() {
let line = line?;
if let Some(row) = line.trim().strip_prefix("NSpid") {
if let Some(row) = line.trim().strip_prefix("NSpid:") {
return Ok(Self(
row.split_ascii_whitespace()
row.trim()
.split_ascii_whitespace()
.map(|pid| pid.parse::<i32>())
.collect::<Result<Vec<_>, _>>()?,
));
@@ -205,9 +206,9 @@ impl ExecParams {
split.next()
})
};
std::os::unix::fs::chown("/proc/self/fd/0", Some(uid), Some(gid)).log_err();
std::os::unix::fs::chown("/proc/self/fd/1", Some(uid), Some(gid)).log_err();
std::os::unix::fs::chown("/proc/self/fd/2", Some(uid), Some(gid)).log_err();
std::os::unix::fs::chown("/proc/self/fd/0", Some(uid), Some(gid)).ok();
std::os::unix::fs::chown("/proc/self/fd/1", Some(uid), Some(gid)).ok();
std::os::unix::fs::chown("/proc/self/fd/2", Some(uid), Some(gid)).ok();
cmd.uid(uid);
cmd.gid(gid);
} else {
@@ -290,8 +291,6 @@ pub fn launch(
None
};
let pty_size = pty_size.or_else(|| TermSize::get_current());
let (stdin_send, stdin_recv) = oneshot::channel::<Box<dyn Write + Send>>();
std::thread::spawn(move || {
if let Ok(mut cstdin) = stdin_recv.blocking_recv() {
@@ -370,7 +369,7 @@ pub fn launch(
.map_err(color_eyre::eyre::Report::msg)
.with_ctx(|_| (ErrorKind::Filesystem, "spawning child process"))?;
send_pid.send(child.id() as i32).unwrap_or_default();
if let Some(pty_size) = pty_size {
if let Some(pty_size) = pty_size.or_else(|| TermSize::get_current()) {
let size = if let Some((x, y)) = pty_size.pixels {
::pty_process::Size::new_with_pixel(pty_size.rows, pty_size.cols, x, y)
} else {
@@ -541,8 +540,6 @@ pub fn exec(
None
};
let pty_size = pty_size.or_else(|| TermSize::get_current());
let (stdin_send, stdin_recv) = oneshot::channel::<Box<dyn Write + Send>>();
std::thread::spawn(move || {
if let Ok(mut cstdin) = stdin_recv.blocking_recv() {
@@ -630,7 +627,7 @@ pub fn exec(
.map_err(color_eyre::eyre::Report::msg)
.with_ctx(|_| (ErrorKind::Filesystem, "spawning child process"))?;
send_pid.send(child.id() as i32).unwrap_or_default();
if let Some(pty_size) = pty_size {
if let Some(pty_size) = pty_size.or_else(|| TermSize::get_current()) {
let size = if let Some((x, y)) = pty_size.pixels {
::pty_process::Size::new_with_pixel(pty_size.rows, pty_size.cols, x, y)
} else {

View File

@@ -8,7 +8,8 @@ use std::process::Stdio;
use std::sync::{Arc, Weak};
use std::time::Duration;
use axum::extract::ws::{Utf8Bytes, WebSocket};
use axum::extract::ws::Utf8Bytes;
use crate::util::net::WebSocket;
use clap::Parser;
use futures::future::BoxFuture;
use futures::stream::FusedStream;
@@ -47,7 +48,6 @@ use crate::util::Never;
use crate::util::actor::concurrent::ConcurrentActor;
use crate::util::future::NonDetachingJoinHandle;
use crate::util::io::{AsyncReadStream, AtomicFile, TermSize, delete_file};
use crate::util::net::WebSocketExt;
use crate::util::serde::Pem;
use crate::util::sync::SyncMutex;
use crate::volume::data_dir;

View File

@@ -5,7 +5,7 @@ use std::time::Duration;
use chrono::Utc;
use clap::Parser;
use color_eyre::eyre::eyre;
use futures::{FutureExt, TryStreamExt};
use futures::FutureExt;
use imbl::vector;
use imbl_value::InternedString;
use rpc_toolkit::{Context, Empty, HandlerExt, ParentHandler, from_fn_async};
@@ -24,7 +24,6 @@ use crate::shutdown::Shutdown;
use crate::util::Invoke;
use crate::util::cpupower::{Governor, get_available_governors, set_governor};
use crate::util::io::open_file;
use crate::util::net::WebSocketExt;
use crate::util::serde::{HandlerExtSerde, WithIoFormat, display_serializable};
use crate::util::sync::Watch;
use crate::{MAIN_DATA, PACKAGE_DATA};
@@ -527,8 +526,8 @@ pub async fn metrics_follow(
.into(),
)).await.with_kind(ErrorKind::Network)?;
}
msg = ws.try_next() => {
if msg.with_kind(crate::ErrorKind::Network)?.is_none() {
msg = ws.recv() => {
if msg.transpose().with_kind(crate::ErrorKind::Network)?.is_none() {
break;
}
}

View File

@@ -27,7 +27,6 @@ use crate::tunnel::auth::SignerInfo;
use crate::tunnel::context::TunnelContext;
use crate::tunnel::web::WebserverInfo;
use crate::tunnel::wg::WgServer;
use crate::util::net::WebSocketExt;
use crate::util::serde::{HandlerExtSerde, apply_expr};
#[derive(Default, Deserialize, Serialize, HasModel, TS)]

View File

@@ -37,7 +37,6 @@ use crate::sound::{
use crate::util::Invoke;
use crate::util::future::NonDetachingJoinHandle;
use crate::util::io::AtomicFile;
use crate::util::net::WebSocketExt;
#[derive(Deserialize, Serialize, Parser, TS)]
#[serde(rename_all = "camelCase")]

View File

@@ -1,40 +1,154 @@
use core::fmt;
use std::pin::Pin;
use std::sync::Mutex;
use std::task::{Context, Poll, ready};
use std::time::Duration;
use axum::extract::ws::{self, CloseFrame, Utf8Bytes};
use futures::{Future, Stream, StreamExt};
use axum::extract::ws::{CloseFrame, Message, Utf8Bytes, WebSocket as AxumWebSocket};
use futures::{Sink, SinkExt, Stream, StreamExt};
use tokio::time::{Instant, Sleep};
use crate::prelude::*;
pub trait WebSocketExt {
fn normal_close(
self,
msg: impl Into<Utf8Bytes> + Send,
) -> impl Future<Output = Result<(), Error>> + Send;
fn close_result(
self,
result: Result<impl Into<Utf8Bytes> + Send, impl fmt::Display + Send>,
) -> impl Future<Output = Result<(), Error>> + Send;
/// Idle time after which `WebSocket::poll_recv` schedules a keep-alive ping.
/// The idle timer is re-armed with this value on every send and on every
/// received message while no ping is outstanding.
const PING_INTERVAL: Duration = Duration::from_secs(30);
/// How long `WebSocket::poll_recv` waits, after sending a ping, for the
/// matching pong before it reports a timeout error.
const PING_TIMEOUT: Duration = Duration::from_secs(300);
/// A wrapper around axum's WebSocket that automatically sends ping frames
/// to keep the connection alive during HTTP/2.
///
/// HTTP/2 streams can timeout if idle, even when the underlying connection
/// has keep-alive enabled. While a receive is being polled, this wrapper sends
/// a ping frame once nothing has been sent or received for `PING_INTERVAL`.
/// If the matching pong does not arrive within `PING_TIMEOUT`, the receive
/// yields a timeout error and every later receive yields `None`.
pub struct WebSocket {
/// The wrapped axum socket that all frames are read from and written to.
inner: AxumWebSocket,
/// Keep-alive bookkeeping: `None` while no ping is outstanding, otherwise
/// `Some((sent, id))`, where `id` is the ping payload (sent as big-endian
/// bytes) and `sent` records whether the frame was handed to the sink yet.
ping_state: Option<(bool, u64)>,
/// Timer armed for `PING_INTERVAL` after activity, or for `PING_TIMEOUT`
/// once a ping frame has been sent.
next_ping: Pin<Box<Sleep>>,
/// Set when the ping timeout error has been yielded; `poll_recv` returns
/// `None` from then on.
fused: bool,
}
impl WebSocketExt for ws::WebSocket {
async fn normal_close(self, msg: impl Into<Utf8Bytes> + Send) -> Result<(), Error> {
self.close_result(Ok::<_, Error>(msg)).await
impl WebSocket {
/// Wraps an axum websocket, starting with no ping outstanding and the idle
/// timer armed for one `PING_INTERVAL` from now.
pub fn new(ws: AxumWebSocket) -> Self {
    let next_ping = Box::pin(tokio::time::sleep(PING_INTERVAL));
    Self {
        inner: ws,
        ping_state: None,
        next_ping,
        fused: false,
    }
}
async fn close_result(
pub fn into_inner(self) -> AxumWebSocket {
self.inner
}
/// Sends `msg` on the underlying socket.
///
/// Unless a keep-alive ping is currently outstanding, the idle timer is
/// pushed back by `PING_INTERVAL` before the message is sent.
pub async fn send(&mut self, msg: Message) -> Result<(), axum::Error> {
    let ping_outstanding = self.ping_state.is_some();
    if !ping_outstanding {
        let deadline = Instant::now() + PING_INTERVAL;
        self.next_ping.as_mut().reset(deadline);
    }
    self.inner.send(msg).await
}
/// Polls for the next message, driving the keep-alive ping state machine.
///
/// Each loop iteration: (1) polls the inner socket and returns any message,
/// except the pong that answers our outstanding ping, which is swallowed;
/// (2) if a ping is pending, writes and flushes it; (3) waits on the timer.
/// When the timer fires with no ping pending, a new ping is scheduled; when
/// it fires with a ping still pending, a timeout error is returned and the
/// socket is marked fused so later calls return `None`.
pub fn poll_recv(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Message, axum::Error>>> {
// A previous call already reported the ping timeout: the stream is over.
if self.fused {
return Poll::Ready(None);
}
let mut inner = Pin::new(&mut self.inner);
loop {
if let Poll::Ready(msg) = inner.as_mut().poll_next(cx) {
// Any inbound item counts as activity, but only while no ping is
// outstanding; otherwise the timer is tracking the pong deadline.
if self.ping_state.is_none() {
self.next_ping
.as_mut()
.reset(Instant::now() + PING_INTERVAL);
}
if let Some(Ok(Message::Pong(x))) = &msg {
// Only a pong for a ping that was actually sent is considered.
if let Some((true, id)) = self.ping_state {
// The pong payload must echo the id we put in the ping.
if &u64::to_be_bytes(id)[..] == &**x {
// Our ping was answered: clear the state, re-arm the idle
// timer, and keep polling instead of surfacing this pong.
self.ping_state.take();
self.next_ping
.as_mut()
.reset(Instant::now() + PING_INTERVAL);
continue;
}
}
}
// Everything else (including end-of-stream and errors) goes to the caller.
break Poll::Ready(msg);
}
if let Some((sent, id)) = &mut self.ping_state {
if !*sent {
// Wait for sink capacity; a sink error is returned to the caller via `?`.
ready!(inner.as_mut().poll_ready(cx))?;
inner
.as_mut()
.start_send(Message::Ping((u64::to_be_bytes(*id).to_vec()).into()))?;
// From here on the timer measures how long the pong may take.
self.next_ping.as_mut().reset(Instant::now() + PING_TIMEOUT);
*sent = true;
}
// Keep flushing on every poll until the ping frame is written out.
ready!(inner.as_mut().poll_flush(cx))?;
}
// Returns `Pending` here until the idle or pong deadline elapses.
ready!(self.next_ping.as_mut().poll(cx));
if self.ping_state.is_some() {
// The deadline elapsed while a ping was still pending: give up.
self.fused = true;
break Poll::Ready(Some(Err(axum::Error::new(eyre!(
"Timeout: WebSocket did not respond to ping within {PING_TIMEOUT:?}"
)))));
}
// Idle for a full interval: queue a ping with a random id and loop to send it.
self.ping_state = Some((false, rand::random()));
}
}
/// Receive a message from the websocket, automatically sending ping frames
/// if the connection is idle for too long.
///
/// This is the async form of `poll_recv`. Only the Pong frame that answers
/// this wrapper's own outstanding ping is consumed internally; every other
/// message the inner socket yields, including other Ping and Pong frames,
/// is returned to the caller.
pub async fn recv(&mut self) -> Option<Result<Message, axum::Error>> {
futures::future::poll_fn(|cx| self.poll_recv(cx)).await
}
/// Performs a normal close: sends a Close frame with code 1000 and `msg` as
/// the reason, then reads from the inner socket until the peer's Close frame
/// or end-of-stream is seen. Send and receive failures are reported as
/// `ErrorKind::Network` errors.
pub async fn normal_close(mut self, msg: impl Into<Utf8Bytes>) -> Result<(), Error> {
    let frame = CloseFrame {
        code: 1000,
        reason: msg.into(),
    };
    self.inner
        .send(Message::Close(Some(frame)))
        .await
        .with_kind(ErrorKind::Network)?;
    // Drain whatever the peer still sends until it acknowledges the close.
    loop {
        let incoming = self
            .inner
            .recv()
            .await
            .transpose()
            .with_kind(ErrorKind::Network)?;
        match incoming {
            Some(Message::Close(_)) | None => return Ok(()),
            Some(_) => {}
        }
    }
}
/// Close the websocket connection with a result.
pub async fn close_result(
mut self,
result: Result<impl Into<Utf8Bytes> + Send, impl fmt::Display + Send>,
) -> Result<(), Error> {
match result {
Ok(msg) => self
.send(ws::Message::Close(Some(CloseFrame {
.inner
.send(Message::Close(Some(CloseFrame {
code: 1000,
reason: msg.into(),
})))
.await
.with_kind(ErrorKind::Network)?,
Err(e) => self
.send(ws::Message::Close(Some(CloseFrame {
.inner
.send(Message::Close(Some(CloseFrame {
code: 1011,
reason: e.to_string().into(),
})))
@@ -42,16 +156,51 @@ impl WebSocketExt for ws::WebSocket {
.with_kind(ErrorKind::Network)?,
}
while !matches!(
self.recv()
self.inner
.recv()
.await
.transpose()
.with_kind(ErrorKind::Network)?,
Some(ws::Message::Close(_)) | None
Some(Message::Close(_)) | None
) {}
Ok(())
}
}
impl From<AxumWebSocket> for WebSocket {
fn from(ws: AxumWebSocket) -> Self {
Self::new(ws)
}
}
impl Stream for WebSocket {
type Item = Result<Message, axum::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.get_mut().poll_recv(cx)
}
}
/// Sink view of the socket: every operation is delegated unchanged to the
/// wrapped axum websocket.
impl Sink<Message> for WebSocket {
    type Error = axum::Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_close(cx)
    }
}
pub struct SyncBody(Mutex<axum::body::BodyDataStream>);
impl From<axum::body::Body> for SyncBody {
fn from(value: axum::body::Body) -> Self {