mirror of
https://github.com/Start9Labs/start-os.git
synced 2026-03-26 02:11:53 +00:00
fix ws timeouts
This commit is contained in:
@@ -28,7 +28,6 @@ use crate::rpc_continuations::{Guid, RpcContinuation, RpcContinuations};
|
||||
use crate::setup::SetupProgress;
|
||||
use crate::shutdown::Shutdown;
|
||||
use crate::util::future::NonDetachingJoinHandle;
|
||||
use crate::util::net::WebSocketExt;
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
pub static ref CURRENT_SECRET: Jwk = Jwk::generate_ec_key(josekit::jwk::alg::ec::EcCurve::P256).unwrap_or_else(|e| {
|
||||
|
||||
@@ -22,7 +22,6 @@ use ts_rs::TS;
|
||||
use crate::context::{CliContext, RpcContext};
|
||||
use crate::prelude::*;
|
||||
use crate::rpc_continuations::{Guid, RpcContinuation};
|
||||
use crate::util::net::WebSocketExt;
|
||||
use crate::util::serde::{HandlerExtSerde, apply_expr};
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
|
||||
@@ -36,7 +36,6 @@ use crate::ssh::SSH_DIR;
|
||||
use crate::system::{get_mem_info, sync_kiosk};
|
||||
use crate::util::io::{IOHook, open_file};
|
||||
use crate::util::lshw::lshw;
|
||||
use crate::util::net::WebSocketExt;
|
||||
use crate::util::{Invoke, cpupower};
|
||||
use crate::{Error, MAIN_DATA, PACKAGE_DATA, ResultExt};
|
||||
|
||||
|
||||
@@ -31,7 +31,6 @@ use crate::s9pk::manifest::PackageId;
|
||||
use crate::s9pk::v2::SIG_CONTEXT;
|
||||
use crate::upload::upload;
|
||||
use crate::util::io::open_file;
|
||||
use crate::util::net::WebSocketExt;
|
||||
use crate::util::tui::choose;
|
||||
use crate::util::{FromStrParser, Never, VersionString};
|
||||
|
||||
|
||||
@@ -5,13 +5,14 @@ use std::process::Stdio;
|
||||
use std::str::FromStr;
|
||||
use std::time::{Duration, UNIX_EPOCH};
|
||||
|
||||
use axum::extract::ws::{self, WebSocket};
|
||||
use axum::extract::ws;
|
||||
use crate::util::net::WebSocket;
|
||||
use chrono::{DateTime, Utc};
|
||||
use clap::builder::ValueParserFactory;
|
||||
use clap::{Args, FromArgMatches, Parser};
|
||||
use color_eyre::eyre::eyre;
|
||||
use futures::stream::BoxStream;
|
||||
use futures::{Future, FutureExt, Stream, StreamExt, TryStreamExt};
|
||||
use futures::{Future, Stream, StreamExt, TryStreamExt};
|
||||
use itertools::Itertools;
|
||||
use rpc_toolkit::yajrc::RpcError;
|
||||
use rpc_toolkit::{
|
||||
@@ -30,7 +31,6 @@ use crate::context::{CliContext, RpcContext};
|
||||
use crate::error::ResultExt;
|
||||
use crate::prelude::*;
|
||||
use crate::rpc_continuations::{Guid, RpcContinuation, RpcContinuations};
|
||||
use crate::util::net::WebSocketExt;
|
||||
use crate::util::serde::Reversible;
|
||||
use crate::util::{FromStrParser, Invoke};
|
||||
|
||||
@@ -100,8 +100,8 @@ async fn ws_handler(
|
||||
return stream.normal_close("complete").await;
|
||||
}
|
||||
},
|
||||
msg = stream.try_next() => {
|
||||
if msg.with_kind(crate::ErrorKind::Network)?.is_none() {
|
||||
msg = stream.recv() => {
|
||||
if msg.transpose().with_kind(crate::ErrorKind::Network)?.is_none() {
|
||||
return Ok(())
|
||||
}
|
||||
}
|
||||
@@ -698,16 +698,11 @@ pub async fn follow_logs<Context: AsRef<RpcContinuations>>(
|
||||
.add(
|
||||
guid.clone(),
|
||||
RpcContinuation::ws(
|
||||
Box::new(move |socket| {
|
||||
ws_handler(first_entry, stream, socket)
|
||||
.map(|x| match x {
|
||||
Ok(_) => (),
|
||||
Err(e) => {
|
||||
tracing::error!("Error in log stream: {}", e);
|
||||
}
|
||||
})
|
||||
.boxed()
|
||||
}),
|
||||
move |socket| async move {
|
||||
if let Err(e) = ws_handler(first_entry, stream, socket).await {
|
||||
tracing::error!("Error in log stream: {}", e);
|
||||
}
|
||||
},
|
||||
Duration::from_secs(30),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -427,7 +427,7 @@ pub async fn connect(ctx: &RpcContext, container: &LxcContainer) -> Result<Guid,
|
||||
|mut ws| async move {
|
||||
if let Err(e) = async {
|
||||
loop {
|
||||
match ws.next().await {
|
||||
match ws.recv().await {
|
||||
None => break,
|
||||
Some(Ok(Message::Text(txt))) => {
|
||||
let mut id = None;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::FutureExt;
|
||||
use http::HeaderValue;
|
||||
@@ -90,11 +91,15 @@ where
|
||||
{
|
||||
let (client, to) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||
.timer(TokioTimer::new())
|
||||
.keep_alive_interval(Duration::from_secs(25))
|
||||
.keep_alive_timeout(Duration::from_secs(300))
|
||||
.handshake(TokioIo::new(to))
|
||||
.await?;
|
||||
let from = hyper::server::conn::http2::Builder::new(TokioExecutor::new())
|
||||
.timer(TokioTimer::new())
|
||||
.enable_connect_protocol()
|
||||
.keep_alive_interval(Duration::from_secs(25)) // Add this
|
||||
.keep_alive_timeout(Duration::from_secs(300))
|
||||
.serve_connection(
|
||||
TokioIo::new(from),
|
||||
service_fn(|mut req| {
|
||||
|
||||
@@ -100,6 +100,15 @@ impl SocksController {
|
||||
TcpStream::connect(addr).await
|
||||
}
|
||||
} {
|
||||
if let Err(e) =
|
||||
socket2::SockRef::from(&target)
|
||||
.set_keepalive(true)
|
||||
{
|
||||
tracing::error!(
|
||||
"Failed to set tcp keepalive: {e}"
|
||||
);
|
||||
tracing::debug!("{e:?}");
|
||||
}
|
||||
let mut sock = reply
|
||||
.reply(
|
||||
Reply::Succeeded,
|
||||
|
||||
@@ -680,11 +680,14 @@ impl TorController {
|
||||
})
|
||||
})
|
||||
}) {
|
||||
Ok(Box::new(
|
||||
TcpStream::connect(target)
|
||||
.await
|
||||
.with_kind(ErrorKind::Network)?,
|
||||
))
|
||||
let tcp_stream = TcpStream::connect(target)
|
||||
.await
|
||||
.with_kind(ErrorKind::Network)?;
|
||||
if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) {
|
||||
tracing::error!("Failed to set tcp keepalive: {e}");
|
||||
tracing::debug!("{e:?}");
|
||||
}
|
||||
Ok(Box::new(tcp_stream))
|
||||
} else {
|
||||
let mut client = self.0.client.clone();
|
||||
client
|
||||
@@ -808,6 +811,10 @@ impl OnionService {
|
||||
TcpStream::connect(target)
|
||||
.await
|
||||
.with_kind(ErrorKind::Network)?;
|
||||
if let Err(e) = socket2::SockRef::from(&outgoing).set_keepalive(true) {
|
||||
tracing::error!("Failed to set tcp keepalive: {e}");
|
||||
tracing::debug!("{e:?}");
|
||||
}
|
||||
let mut incoming = req
|
||||
.accept(Connected::new_empty())
|
||||
.await
|
||||
|
||||
@@ -499,15 +499,22 @@ impl TorController {
|
||||
})
|
||||
}) {
|
||||
tracing::debug!("Resolving {addr} internally to {target}");
|
||||
Ok(Box::new(
|
||||
TcpStream::connect(target)
|
||||
.await
|
||||
.with_kind(ErrorKind::Network)?,
|
||||
))
|
||||
let tcp_stream = TcpStream::connect(target)
|
||||
.await
|
||||
.with_kind(ErrorKind::Network)?;
|
||||
if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) {
|
||||
tracing::error!("Failed to set tcp keepalive: {e}");
|
||||
tracing::debug!("{e:?}");
|
||||
}
|
||||
Ok(Box::new(tcp_stream))
|
||||
} else {
|
||||
let mut stream = TcpStream::connect(TOR_SOCKS)
|
||||
.await
|
||||
.with_kind(ErrorKind::Tor)?;
|
||||
if let Err(e) = socket2::SockRef::from(&stream).set_keepalive(true) {
|
||||
tracing::error!("Failed to set tcp keepalive: {e}");
|
||||
tracing::debug!("{e:?}");
|
||||
}
|
||||
socks5_impl::client::connect(&mut stream, (addr.to_string(), port), None)
|
||||
.await
|
||||
.with_kind(ErrorKind::Tor)?;
|
||||
@@ -665,6 +672,11 @@ async fn torctl(
|
||||
})?;
|
||||
tracing::info!("Tor is started");
|
||||
|
||||
if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) {
|
||||
tracing::error!("Failed to set tcp keepalive: {e}");
|
||||
tracing::debug!("{e:?}");
|
||||
}
|
||||
|
||||
let mut conn = torut::control::UnauthenticatedConn::new(tcp_stream);
|
||||
let auth = conn
|
||||
.load_protocol_info()
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::fmt;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::task::{Poll, ready};
|
||||
use std::time::Duration;
|
||||
|
||||
use async_acme::acme::ACME_TLS_ALPN_NAME;
|
||||
use color_eyre::eyre::eyre;
|
||||
@@ -359,6 +360,10 @@ where
|
||||
.await
|
||||
.with_ctx(|_| (ErrorKind::Network, self.addr))
|
||||
.log_err()?;
|
||||
if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) {
|
||||
tracing::error!("Failed to set tcp keepalive: {e}");
|
||||
tracing::debug!("{e:?}");
|
||||
}
|
||||
match &self.connect_ssl {
|
||||
Ok(client_cfg) => {
|
||||
let mut client_cfg = (&**client_cfg).clone();
|
||||
|
||||
@@ -106,12 +106,7 @@ impl Accept for TcpListener {
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(Self::Metadata, AcceptStream), Error>> {
|
||||
if let Poll::Ready((stream, peer_addr)) = TcpListener::poll_accept(self, cx)? {
|
||||
if let Err(e) = socket2::SockRef::from(&stream).set_tcp_keepalive(
|
||||
&socket2::TcpKeepalive::new()
|
||||
.with_time(Duration::from_secs(900))
|
||||
.with_interval(Duration::from_secs(60))
|
||||
.with_retries(5),
|
||||
) {
|
||||
if let Err(e) = socket2::SockRef::from(&stream).set_keepalive(true) {
|
||||
tracing::error!("Failed to set tcp keepalive: {e}");
|
||||
tracing::debug!("{e:?}");
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::extract::Request;
|
||||
use axum::extract::ws::WebSocket;
|
||||
use axum::extract::ws::WebSocket as AxumWebSocket;
|
||||
use axum::response::Response;
|
||||
use clap::builder::ValueParserFactory;
|
||||
use futures::future::BoxFuture;
|
||||
@@ -18,6 +18,7 @@ use ts_rs::TS;
|
||||
#[allow(unused_imports)]
|
||||
use crate::prelude::*;
|
||||
use crate::util::future::TimedResource;
|
||||
use crate::util::net::WebSocket;
|
||||
use crate::util::{FromStrParser, new_guid};
|
||||
|
||||
#[derive(
|
||||
@@ -109,7 +110,7 @@ impl Future for WebSocketFuture {
|
||||
}
|
||||
}
|
||||
}
|
||||
pub type WebSocketHandler = Box<dyn FnOnce(WebSocket) -> WebSocketFuture + Send>;
|
||||
pub type WebSocketHandler = Box<dyn FnOnce(AxumWebSocket) -> WebSocketFuture + Send>;
|
||||
|
||||
pub enum RpcContinuation {
|
||||
Rest(TimedResource<RestHandler>),
|
||||
@@ -137,7 +138,7 @@ impl RpcContinuation {
|
||||
RpcContinuation::WebSocket(TimedResource::new(
|
||||
Box::new(|ws| WebSocketFuture {
|
||||
kill: None,
|
||||
fut: handler(ws).boxed(),
|
||||
fut: handler(ws.into()).boxed(),
|
||||
}),
|
||||
timeout,
|
||||
))
|
||||
@@ -169,7 +170,7 @@ impl RpcContinuation {
|
||||
RpcContinuation::WebSocket(TimedResource::new(
|
||||
Box::new(|ws| WebSocketFuture {
|
||||
kill,
|
||||
fut: handler(ws).boxed(),
|
||||
fut: handler(ws.into()).boxed(),
|
||||
}),
|
||||
timeout,
|
||||
))
|
||||
|
||||
@@ -83,9 +83,10 @@ impl procfs::FromBufRead for NSPid {
|
||||
fn from_buf_read<R: std::io::BufRead>(r: R) -> procfs::ProcResult<Self> {
|
||||
for line in r.lines() {
|
||||
let line = line?;
|
||||
if let Some(row) = line.trim().strip_prefix("NSpid") {
|
||||
if let Some(row) = line.trim().strip_prefix("NSpid:") {
|
||||
return Ok(Self(
|
||||
row.split_ascii_whitespace()
|
||||
row.trim()
|
||||
.split_ascii_whitespace()
|
||||
.map(|pid| pid.parse::<i32>())
|
||||
.collect::<Result<Vec<_>, _>>()?,
|
||||
));
|
||||
@@ -205,9 +206,9 @@ impl ExecParams {
|
||||
split.next()
|
||||
})
|
||||
};
|
||||
std::os::unix::fs::chown("/proc/self/fd/0", Some(uid), Some(gid)).log_err();
|
||||
std::os::unix::fs::chown("/proc/self/fd/1", Some(uid), Some(gid)).log_err();
|
||||
std::os::unix::fs::chown("/proc/self/fd/2", Some(uid), Some(gid)).log_err();
|
||||
std::os::unix::fs::chown("/proc/self/fd/0", Some(uid), Some(gid)).ok();
|
||||
std::os::unix::fs::chown("/proc/self/fd/1", Some(uid), Some(gid)).ok();
|
||||
std::os::unix::fs::chown("/proc/self/fd/2", Some(uid), Some(gid)).ok();
|
||||
cmd.uid(uid);
|
||||
cmd.gid(gid);
|
||||
} else {
|
||||
@@ -290,8 +291,6 @@ pub fn launch(
|
||||
None
|
||||
};
|
||||
|
||||
let pty_size = pty_size.or_else(|| TermSize::get_current());
|
||||
|
||||
let (stdin_send, stdin_recv) = oneshot::channel::<Box<dyn Write + Send>>();
|
||||
std::thread::spawn(move || {
|
||||
if let Ok(mut cstdin) = stdin_recv.blocking_recv() {
|
||||
@@ -370,7 +369,7 @@ pub fn launch(
|
||||
.map_err(color_eyre::eyre::Report::msg)
|
||||
.with_ctx(|_| (ErrorKind::Filesystem, "spawning child process"))?;
|
||||
send_pid.send(child.id() as i32).unwrap_or_default();
|
||||
if let Some(pty_size) = pty_size {
|
||||
if let Some(pty_size) = pty_size.or_else(|| TermSize::get_current()) {
|
||||
let size = if let Some((x, y)) = pty_size.pixels {
|
||||
::pty_process::Size::new_with_pixel(pty_size.rows, pty_size.cols, x, y)
|
||||
} else {
|
||||
@@ -541,8 +540,6 @@ pub fn exec(
|
||||
None
|
||||
};
|
||||
|
||||
let pty_size = pty_size.or_else(|| TermSize::get_current());
|
||||
|
||||
let (stdin_send, stdin_recv) = oneshot::channel::<Box<dyn Write + Send>>();
|
||||
std::thread::spawn(move || {
|
||||
if let Ok(mut cstdin) = stdin_recv.blocking_recv() {
|
||||
@@ -630,7 +627,7 @@ pub fn exec(
|
||||
.map_err(color_eyre::eyre::Report::msg)
|
||||
.with_ctx(|_| (ErrorKind::Filesystem, "spawning child process"))?;
|
||||
send_pid.send(child.id() as i32).unwrap_or_default();
|
||||
if let Some(pty_size) = pty_size {
|
||||
if let Some(pty_size) = pty_size.or_else(|| TermSize::get_current()) {
|
||||
let size = if let Some((x, y)) = pty_size.pixels {
|
||||
::pty_process::Size::new_with_pixel(pty_size.rows, pty_size.cols, x, y)
|
||||
} else {
|
||||
|
||||
@@ -8,7 +8,8 @@ use std::process::Stdio;
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::extract::ws::{Utf8Bytes, WebSocket};
|
||||
use axum::extract::ws::Utf8Bytes;
|
||||
use crate::util::net::WebSocket;
|
||||
use clap::Parser;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::stream::FusedStream;
|
||||
@@ -47,7 +48,6 @@ use crate::util::Never;
|
||||
use crate::util::actor::concurrent::ConcurrentActor;
|
||||
use crate::util::future::NonDetachingJoinHandle;
|
||||
use crate::util::io::{AsyncReadStream, AtomicFile, TermSize, delete_file};
|
||||
use crate::util::net::WebSocketExt;
|
||||
use crate::util::serde::Pem;
|
||||
use crate::util::sync::SyncMutex;
|
||||
use crate::volume::data_dir;
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::time::Duration;
|
||||
use chrono::Utc;
|
||||
use clap::Parser;
|
||||
use color_eyre::eyre::eyre;
|
||||
use futures::{FutureExt, TryStreamExt};
|
||||
use futures::FutureExt;
|
||||
use imbl::vector;
|
||||
use imbl_value::InternedString;
|
||||
use rpc_toolkit::{Context, Empty, HandlerExt, ParentHandler, from_fn_async};
|
||||
@@ -24,7 +24,6 @@ use crate::shutdown::Shutdown;
|
||||
use crate::util::Invoke;
|
||||
use crate::util::cpupower::{Governor, get_available_governors, set_governor};
|
||||
use crate::util::io::open_file;
|
||||
use crate::util::net::WebSocketExt;
|
||||
use crate::util::serde::{HandlerExtSerde, WithIoFormat, display_serializable};
|
||||
use crate::util::sync::Watch;
|
||||
use crate::{MAIN_DATA, PACKAGE_DATA};
|
||||
@@ -527,8 +526,8 @@ pub async fn metrics_follow(
|
||||
.into(),
|
||||
)).await.with_kind(ErrorKind::Network)?;
|
||||
}
|
||||
msg = ws.try_next() => {
|
||||
if msg.with_kind(crate::ErrorKind::Network)?.is_none() {
|
||||
msg = ws.recv() => {
|
||||
if msg.transpose().with_kind(crate::ErrorKind::Network)?.is_none() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,7 +27,6 @@ use crate::tunnel::auth::SignerInfo;
|
||||
use crate::tunnel::context::TunnelContext;
|
||||
use crate::tunnel::web::WebserverInfo;
|
||||
use crate::tunnel::wg::WgServer;
|
||||
use crate::util::net::WebSocketExt;
|
||||
use crate::util::serde::{HandlerExtSerde, apply_expr};
|
||||
|
||||
#[derive(Default, Deserialize, Serialize, HasModel, TS)]
|
||||
|
||||
@@ -37,7 +37,6 @@ use crate::sound::{
|
||||
use crate::util::Invoke;
|
||||
use crate::util::future::NonDetachingJoinHandle;
|
||||
use crate::util::io::AtomicFile;
|
||||
use crate::util::net::WebSocketExt;
|
||||
|
||||
#[derive(Deserialize, Serialize, Parser, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
|
||||
@@ -1,40 +1,154 @@
|
||||
use core::fmt;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Mutex;
|
||||
use std::task::{Context, Poll, ready};
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::extract::ws::{self, CloseFrame, Utf8Bytes};
|
||||
use futures::{Future, Stream, StreamExt};
|
||||
use axum::extract::ws::{CloseFrame, Message, Utf8Bytes, WebSocket as AxumWebSocket};
|
||||
use futures::{Sink, SinkExt, Stream, StreamExt};
|
||||
use tokio::time::{Instant, Sleep};
|
||||
|
||||
use crate::prelude::*;
|
||||
|
||||
pub trait WebSocketExt {
|
||||
fn normal_close(
|
||||
self,
|
||||
msg: impl Into<Utf8Bytes> + Send,
|
||||
) -> impl Future<Output = Result<(), Error>> + Send;
|
||||
fn close_result(
|
||||
self,
|
||||
result: Result<impl Into<Utf8Bytes> + Send, impl fmt::Display + Send>,
|
||||
) -> impl Future<Output = Result<(), Error>> + Send;
|
||||
const PING_INTERVAL: Duration = Duration::from_secs(30);
|
||||
const PING_TIMEOUT: Duration = Duration::from_secs(300);
|
||||
|
||||
/// A wrapper around axum's WebSocket that automatically sends ping frames
|
||||
/// to keep the connection alive during HTTP/2.
|
||||
///
|
||||
/// HTTP/2 streams can timeout if idle, even when the underlying connection
|
||||
/// has keep-alive enabled. This wrapper sends a ping frame if no data has
|
||||
/// been sent within the ping interval while waiting to receive a message.
|
||||
pub struct WebSocket {
|
||||
inner: AxumWebSocket,
|
||||
ping_state: Option<(bool, u64)>,
|
||||
next_ping: Pin<Box<Sleep>>,
|
||||
fused: bool,
|
||||
}
|
||||
|
||||
impl WebSocketExt for ws::WebSocket {
|
||||
async fn normal_close(self, msg: impl Into<Utf8Bytes> + Send) -> Result<(), Error> {
|
||||
self.close_result(Ok::<_, Error>(msg)).await
|
||||
impl WebSocket {
|
||||
pub fn new(ws: AxumWebSocket) -> Self {
|
||||
Self {
|
||||
inner: ws,
|
||||
ping_state: None,
|
||||
next_ping: Box::pin(tokio::time::sleep(PING_INTERVAL)),
|
||||
fused: false,
|
||||
}
|
||||
}
|
||||
async fn close_result(
|
||||
|
||||
pub fn into_inner(self) -> AxumWebSocket {
|
||||
self.inner
|
||||
}
|
||||
|
||||
pub async fn send(&mut self, msg: Message) -> Result<(), axum::Error> {
|
||||
if self.ping_state.is_none() {
|
||||
self.next_ping
|
||||
.as_mut()
|
||||
.reset(Instant::now() + PING_INTERVAL);
|
||||
}
|
||||
self.inner.send(msg).await
|
||||
}
|
||||
|
||||
pub fn poll_recv(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Message, axum::Error>>> {
|
||||
if self.fused {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
let mut inner = Pin::new(&mut self.inner);
|
||||
|
||||
loop {
|
||||
if let Poll::Ready(msg) = inner.as_mut().poll_next(cx) {
|
||||
if self.ping_state.is_none() {
|
||||
self.next_ping
|
||||
.as_mut()
|
||||
.reset(Instant::now() + PING_INTERVAL);
|
||||
}
|
||||
|
||||
if let Some(Ok(Message::Pong(x))) = &msg {
|
||||
if let Some((true, id)) = self.ping_state {
|
||||
if &u64::to_be_bytes(id)[..] == &**x {
|
||||
self.ping_state.take();
|
||||
self.next_ping
|
||||
.as_mut()
|
||||
.reset(Instant::now() + PING_INTERVAL);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
break Poll::Ready(msg);
|
||||
}
|
||||
|
||||
if let Some((sent, id)) = &mut self.ping_state {
|
||||
if !*sent {
|
||||
ready!(inner.as_mut().poll_ready(cx))?;
|
||||
inner
|
||||
.as_mut()
|
||||
.start_send(Message::Ping((u64::to_be_bytes(*id).to_vec()).into()))?;
|
||||
self.next_ping.as_mut().reset(Instant::now() + PING_TIMEOUT);
|
||||
*sent = true;
|
||||
}
|
||||
ready!(inner.as_mut().poll_flush(cx))?;
|
||||
}
|
||||
|
||||
ready!(self.next_ping.as_mut().poll(cx));
|
||||
if self.ping_state.is_some() {
|
||||
self.fused = true;
|
||||
break Poll::Ready(Some(Err(axum::Error::new(eyre!(
|
||||
"Timeout: WebSocket did not respond to ping within {PING_TIMEOUT:?}"
|
||||
)))));
|
||||
}
|
||||
self.ping_state = Some((false, rand::random()));
|
||||
}
|
||||
}
|
||||
|
||||
/// Receive a message from the websocket, automatically sending ping frames
|
||||
/// if the connection is idle for too long.
|
||||
///
|
||||
/// Ping and Pong frames are handled internally and not returned to the caller.
|
||||
pub async fn recv(&mut self) -> Option<Result<Message, axum::Error>> {
|
||||
futures::future::poll_fn(|cx| self.poll_recv(cx)).await
|
||||
}
|
||||
|
||||
/// Close the websocket connection normally.
|
||||
pub async fn normal_close(mut self, msg: impl Into<Utf8Bytes>) -> Result<(), Error> {
|
||||
self.inner
|
||||
.send(Message::Close(Some(CloseFrame {
|
||||
code: 1000,
|
||||
reason: msg.into(),
|
||||
})))
|
||||
.await
|
||||
.with_kind(ErrorKind::Network)?;
|
||||
while !matches!(
|
||||
self.inner
|
||||
.recv()
|
||||
.await
|
||||
.transpose()
|
||||
.with_kind(ErrorKind::Network)?,
|
||||
Some(Message::Close(_)) | None
|
||||
) {}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Close the websocket connection with a result.
|
||||
pub async fn close_result(
|
||||
mut self,
|
||||
result: Result<impl Into<Utf8Bytes> + Send, impl fmt::Display + Send>,
|
||||
) -> Result<(), Error> {
|
||||
match result {
|
||||
Ok(msg) => self
|
||||
.send(ws::Message::Close(Some(CloseFrame {
|
||||
.inner
|
||||
.send(Message::Close(Some(CloseFrame {
|
||||
code: 1000,
|
||||
reason: msg.into(),
|
||||
})))
|
||||
.await
|
||||
.with_kind(ErrorKind::Network)?,
|
||||
Err(e) => self
|
||||
.send(ws::Message::Close(Some(CloseFrame {
|
||||
.inner
|
||||
.send(Message::Close(Some(CloseFrame {
|
||||
code: 1011,
|
||||
reason: e.to_string().into(),
|
||||
})))
|
||||
@@ -42,16 +156,51 @@ impl WebSocketExt for ws::WebSocket {
|
||||
.with_kind(ErrorKind::Network)?,
|
||||
}
|
||||
while !matches!(
|
||||
self.recv()
|
||||
self.inner
|
||||
.recv()
|
||||
.await
|
||||
.transpose()
|
||||
.with_kind(ErrorKind::Network)?,
|
||||
Some(ws::Message::Close(_)) | None
|
||||
Some(Message::Close(_)) | None
|
||||
) {}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AxumWebSocket> for WebSocket {
|
||||
fn from(ws: AxumWebSocket) -> Self {
|
||||
Self::new(ws)
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for WebSocket {
|
||||
type Item = Result<Message, axum::Error>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
self.get_mut().poll_recv(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<Message> for WebSocket {
|
||||
type Error = axum::Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.get_mut().inner.poll_ready_unpin(cx)
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
|
||||
self.get_mut().inner.start_send_unpin(item)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.get_mut().inner.poll_flush_unpin(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.get_mut().inner.poll_close_unpin(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SyncBody(Mutex<axum::body::BodyDataStream>);
|
||||
impl From<axum::body::Body> for SyncBody {
|
||||
fn from(value: axum::body::Body) -> Self {
|
||||
|
||||
Reference in New Issue
Block a user