diff --git a/core/startos/src/context/setup.rs b/core/startos/src/context/setup.rs index 49c87e001..ecef31f83 100644 --- a/core/startos/src/context/setup.rs +++ b/core/startos/src/context/setup.rs @@ -28,7 +28,6 @@ use crate::rpc_continuations::{Guid, RpcContinuation, RpcContinuations}; use crate::setup::SetupProgress; use crate::shutdown::Shutdown; use crate::util::future::NonDetachingJoinHandle; -use crate::util::net::WebSocketExt; lazy_static::lazy_static! { pub static ref CURRENT_SECRET: Jwk = Jwk::generate_ec_key(josekit::jwk::alg::ec::EcCurve::P256).unwrap_or_else(|e| { diff --git a/core/startos/src/db/mod.rs b/core/startos/src/db/mod.rs index a6e28f1fe..bdcce8f7f 100644 --- a/core/startos/src/db/mod.rs +++ b/core/startos/src/db/mod.rs @@ -22,7 +22,6 @@ use ts_rs::TS; use crate::context::{CliContext, RpcContext}; use crate::prelude::*; use crate::rpc_continuations::{Guid, RpcContinuation}; -use crate::util::net::WebSocketExt; use crate::util::serde::{HandlerExtSerde, apply_expr}; lazy_static::lazy_static! { diff --git a/core/startos/src/init.rs b/core/startos/src/init.rs index 0a42f7291..37ad4374f 100644 --- a/core/startos/src/init.rs +++ b/core/startos/src/init.rs @@ -36,7 +36,6 @@ use crate::ssh::SSH_DIR; use crate::system::{get_mem_info, sync_kiosk}; use crate::util::io::{IOHook, open_file}; use crate::util::lshw::lshw; -use crate::util::net::WebSocketExt; use crate::util::{Invoke, cpupower}; use crate::{Error, MAIN_DATA, PACKAGE_DATA, ResultExt}; diff --git a/core/startos/src/install/mod.rs b/core/startos/src/install/mod.rs index 64b8ceda7..7444e5f1b 100644 --- a/core/startos/src/install/mod.rs +++ b/core/startos/src/install/mod.rs @@ -31,7 +31,6 @@ use crate::s9pk::manifest::PackageId; use crate::s9pk::v2::SIG_CONTEXT; use crate::upload::upload; use crate::util::io::open_file; -use crate::util::net::WebSocketExt; use crate::util::tui::choose; use crate::util::{FromStrParser, Never, VersionString}; diff --git a/core/startos/src/logs.rs b/core/startos/src/logs.rs index 5a2ee4587..d4c7c97da 100644 --- a/core/startos/src/logs.rs +++ b/core/startos/src/logs.rs @@ -5,13 +5,14 @@ use std::process::Stdio; use std::str::FromStr; use std::time::{Duration, UNIX_EPOCH}; -use axum::extract::ws::{self, WebSocket}; +use axum::extract::ws; +use crate::util::net::WebSocket; use chrono::{DateTime, Utc}; use clap::builder::ValueParserFactory; use clap::{Args, FromArgMatches, Parser}; use color_eyre::eyre::eyre; use futures::stream::BoxStream; -use futures::{Future, FutureExt, Stream, StreamExt, TryStreamExt}; +use futures::{Future, Stream, StreamExt, TryStreamExt}; use itertools::Itertools; use rpc_toolkit::yajrc::RpcError; use rpc_toolkit::{ @@ -30,7 +31,6 @@ use crate::context::{CliContext, RpcContext}; use crate::error::ResultExt; use crate::prelude::*; use crate::rpc_continuations::{Guid, RpcContinuation, RpcContinuations}; -use crate::util::net::WebSocketExt; use crate::util::serde::Reversible; use crate::util::{FromStrParser, Invoke}; @@ -100,8 +100,8 @@ async fn ws_handler( return stream.normal_close("complete").await; } }, - msg = stream.try_next() => { - if msg.with_kind(crate::ErrorKind::Network)?.is_none() { + msg = stream.recv() => { + if msg.transpose().with_kind(crate::ErrorKind::Network)?.is_none() { return Ok(()) } } @@ -698,16 +698,11 @@ pub async fn follow_logs>( .add( guid.clone(), RpcContinuation::ws( - Box::new(move |socket| { - ws_handler(first_entry, stream, socket) - .map(|x| match x { - Ok(_) => (), - Err(e) => { - tracing::error!("Error in log stream: {}", e); - } - }) - .boxed() - }), + move |socket| async move { + if let Err(e) = ws_handler(first_entry, stream, socket).await { + tracing::error!("Error in log stream: {}", e); + } + }, Duration::from_secs(30), ), ) diff --git a/core/startos/src/lxc/mod.rs b/core/startos/src/lxc/mod.rs index 30734c322..8091577d2 100644 --- a/core/startos/src/lxc/mod.rs +++ b/core/startos/src/lxc/mod.rs @@ -427,7 +427,7 @@ pub async fn connect(ctx: &RpcContext, container: &LxcContainer) -> Result break, Some(Ok(Message::Text(txt))) => { let mut id = None; diff --git a/core/startos/src/net/http.rs b/core/startos/src/net/http.rs index 50e765276..91a96e688 100644 --- a/core/startos/src/net/http.rs +++ b/core/startos/src/net/http.rs @@ -1,5 +1,6 @@ use std::net::IpAddr; use std::sync::Arc; +use std::time::Duration; use futures::FutureExt; use http::HeaderValue; @@ -90,11 +91,15 @@ where { let (client, to) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()) .timer(TokioTimer::new()) + .keep_alive_interval(Duration::from_secs(25)) + .keep_alive_timeout(Duration::from_secs(300)) .handshake(TokioIo::new(to)) .await?; let from = hyper::server::conn::http2::Builder::new(TokioExecutor::new()) .timer(TokioTimer::new()) .enable_connect_protocol() + .keep_alive_interval(Duration::from_secs(25)) // Add this + .keep_alive_timeout(Duration::from_secs(300)) .serve_connection( TokioIo::new(from), service_fn(|mut req| { diff --git a/core/startos/src/net/socks.rs b/core/startos/src/net/socks.rs index 0cd2645d7..a035f2f7c 100644 --- a/core/startos/src/net/socks.rs +++ b/core/startos/src/net/socks.rs @@ -100,6 +100,15 @@ impl SocksController { TcpStream::connect(addr).await } } { + if let Err(e) = + socket2::SockRef::from(&target) + .set_keepalive(true) + { + tracing::error!( + "Failed to set tcp keepalive: {e}" + ); + tracing::debug!("{e:?}"); + } let mut sock = reply .reply( Reply::Succeeded, diff --git a/core/startos/src/net/tor/arti.rs b/core/startos/src/net/tor/arti.rs index 71a82b41d..8c9f54c4f 100644 --- a/core/startos/src/net/tor/arti.rs +++ b/core/startos/src/net/tor/arti.rs @@ -680,11 +680,14 @@ impl TorController { }) }) }) { - Ok(Box::new( - TcpStream::connect(target) - .await - .with_kind(ErrorKind::Network)?, - )) + let tcp_stream = TcpStream::connect(target) + .await + .with_kind(ErrorKind::Network)?; + if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) { + tracing::error!("Failed to set tcp keepalive: {e}"); + tracing::debug!("{e:?}"); + } + Ok(Box::new(tcp_stream)) } else { let mut client = self.0.client.clone(); client @@ -808,6 +811,10 @@ impl OnionService { TcpStream::connect(target) .await .with_kind(ErrorKind::Network)?; + if let Err(e) = socket2::SockRef::from(&outgoing).set_keepalive(true) { + tracing::error!("Failed to set tcp keepalive: {e}"); + tracing::debug!("{e:?}"); + } let mut incoming = req .accept(Connected::new_empty()) .await diff --git a/core/startos/src/net/tor/ctor.rs b/core/startos/src/net/tor/ctor.rs index 5ce62e137..71eb81baf 100644 --- a/core/startos/src/net/tor/ctor.rs +++ b/core/startos/src/net/tor/ctor.rs @@ -499,15 +499,22 @@ impl TorController { }) }) { tracing::debug!("Resolving {addr} internally to {target}"); - Ok(Box::new( - TcpStream::connect(target) - .await - .with_kind(ErrorKind::Network)?, - )) + let tcp_stream = TcpStream::connect(target) + .await + .with_kind(ErrorKind::Network)?; + if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) { + tracing::error!("Failed to set tcp keepalive: {e}"); + tracing::debug!("{e:?}"); + } + Ok(Box::new(tcp_stream)) } else { let mut stream = TcpStream::connect(TOR_SOCKS) .await .with_kind(ErrorKind::Tor)?; + if let Err(e) = socket2::SockRef::from(&stream).set_keepalive(true) { + tracing::error!("Failed to set tcp keepalive: {e}"); + tracing::debug!("{e:?}"); + } socks5_impl::client::connect(&mut stream, (addr.to_string(), port), None) .await .with_kind(ErrorKind::Tor)?; @@ -665,6 +672,11 @@ async fn torctl( })?; tracing::info!("Tor is started"); + if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) { + tracing::error!("Failed to set tcp keepalive: {e}"); + tracing::debug!("{e:?}"); + } + let mut conn = torut::control::UnauthenticatedConn::new(tcp_stream); let auth = conn .load_protocol_info() diff --git a/core/startos/src/net/vhost.rs b/core/startos/src/net/vhost.rs index 3f4cbd181..4996ca937 100644 --- a/core/startos/src/net/vhost.rs +++ b/core/startos/src/net/vhost.rs @@ -4,6 +4,7 @@ use std::fmt; use std::net::{IpAddr, SocketAddr}; use std::sync::{Arc, Weak}; use std::task::{Poll, ready}; +use std::time::Duration; use async_acme::acme::ACME_TLS_ALPN_NAME; use color_eyre::eyre::eyre; @@ -359,6 +360,10 @@ where .await .with_ctx(|_| (ErrorKind::Network, self.addr)) .log_err()?; + if let Err(e) = socket2::SockRef::from(&tcp_stream).set_keepalive(true) { + tracing::error!("Failed to set tcp keepalive: {e}"); + tracing::debug!("{e:?}"); + } match &self.connect_ssl { Ok(client_cfg) => { let mut client_cfg = (&**client_cfg).clone(); diff --git a/core/startos/src/net/web_server.rs b/core/startos/src/net/web_server.rs index 25e1df79a..2ac5b035f 100644 --- a/core/startos/src/net/web_server.rs +++ b/core/startos/src/net/web_server.rs @@ -106,12 +106,7 @@ impl Accept for TcpListener { cx: &mut std::task::Context<'_>, ) -> Poll> { if let Poll::Ready((stream, peer_addr)) = TcpListener::poll_accept(self, cx)? { - if let Err(e) = socket2::SockRef::from(&stream).set_tcp_keepalive( - &socket2::TcpKeepalive::new() - .with_time(Duration::from_secs(900)) - .with_interval(Duration::from_secs(60)) - .with_retries(5), - ) { + if let Err(e) = socket2::SockRef::from(&stream).set_keepalive(true) { tracing::error!("Failed to set tcp keepalive: {e}"); tracing::debug!("{e:?}"); } diff --git a/core/startos/src/rpc_continuations.rs b/core/startos/src/rpc_continuations.rs index aa3f24236..42c3ae858 100644 --- a/core/startos/src/rpc_continuations.rs +++ b/core/startos/src/rpc_continuations.rs @@ -6,7 +6,7 @@ use std::task::{Context, Poll}; use std::time::Duration; use axum::extract::Request; -use axum::extract::ws::WebSocket; +use axum::extract::ws::WebSocket as AxumWebSocket; use axum::response::Response; use clap::builder::ValueParserFactory; use futures::future::BoxFuture; @@ -18,6 +18,7 @@ use ts_rs::TS; #[allow(unused_imports)] use crate::prelude::*; use crate::util::future::TimedResource; +use crate::util::net::WebSocket; use crate::util::{FromStrParser, new_guid}; #[derive( @@ -109,7 +110,7 @@ impl Future for WebSocketFuture { } } } -pub type WebSocketHandler = Box WebSocketFuture + Send>; +pub type WebSocketHandler = Box WebSocketFuture + Send>; pub enum RpcContinuation { Rest(TimedResource), @@ -137,7 +138,7 @@ impl RpcContinuation { RpcContinuation::WebSocket(TimedResource::new( Box::new(|ws| WebSocketFuture { kill: None, - fut: handler(ws).boxed(), + fut: handler(ws.into()).boxed(), }), timeout, )) @@ -169,7 +170,7 @@ impl RpcContinuation { RpcContinuation::WebSocket(TimedResource::new( Box::new(|ws| WebSocketFuture { kill, - fut: handler(ws).boxed(), + fut: handler(ws.into()).boxed(), }), timeout, )) diff --git a/core/startos/src/service/effects/subcontainer/sync.rs b/core/startos/src/service/effects/subcontainer/sync.rs index 1fa43fcb6..dba63c1cd 100644 --- a/core/startos/src/service/effects/subcontainer/sync.rs +++ b/core/startos/src/service/effects/subcontainer/sync.rs @@ -83,9 +83,10 @@ impl procfs::FromBufRead for NSPid { fn from_buf_read(r: R) -> procfs::ProcResult { for line in r.lines() { let line = line?; - if let Some(row) = line.trim().strip_prefix("NSpid") { + if let Some(row) = line.trim().strip_prefix("NSpid:") { return Ok(Self( - row.split_ascii_whitespace() + row.trim() + .split_ascii_whitespace() .map(|pid| pid.parse::()) .collect::, _>>()?, )); @@ -205,9 +206,9 @@ impl ExecParams { split.next() }) }; - std::os::unix::fs::chown("/proc/self/fd/0", Some(uid), Some(gid)).log_err(); - std::os::unix::fs::chown("/proc/self/fd/1", Some(uid), Some(gid)).log_err(); - std::os::unix::fs::chown("/proc/self/fd/2", Some(uid), Some(gid)).log_err(); + std::os::unix::fs::chown("/proc/self/fd/0", Some(uid), Some(gid)).ok(); + std::os::unix::fs::chown("/proc/self/fd/1", Some(uid), Some(gid)).ok(); + std::os::unix::fs::chown("/proc/self/fd/2", Some(uid), Some(gid)).ok(); cmd.uid(uid); cmd.gid(gid); } else { @@ -290,8 +291,6 @@ pub fn launch( None }; - let pty_size = pty_size.or_else(|| TermSize::get_current()); - let (stdin_send, stdin_recv) = oneshot::channel::>(); std::thread::spawn(move || { if let Ok(mut cstdin) = stdin_recv.blocking_recv() { @@ -370,7 +369,7 @@ pub fn launch( .map_err(color_eyre::eyre::Report::msg) .with_ctx(|_| (ErrorKind::Filesystem, "spawning child process"))?; send_pid.send(child.id() as i32).unwrap_or_default(); - if let Some(pty_size) = pty_size { + if let Some(pty_size) = pty_size.or_else(|| TermSize::get_current()) { let size = if let Some((x, y)) = pty_size.pixels { ::pty_process::Size::new_with_pixel(pty_size.rows, pty_size.cols, x, y) } else { @@ -541,8 +540,6 @@ pub fn exec( None }; - let pty_size = pty_size.or_else(|| TermSize::get_current()); - let (stdin_send, stdin_recv) = oneshot::channel::>(); std::thread::spawn(move || { if let Ok(mut cstdin) = stdin_recv.blocking_recv() { @@ -630,7 +627,7 @@ pub fn exec( .map_err(color_eyre::eyre::Report::msg) .with_ctx(|_| (ErrorKind::Filesystem, "spawning child process"))?; send_pid.send(child.id() as i32).unwrap_or_default(); - if let Some(pty_size) = pty_size { + if let Some(pty_size) = pty_size.or_else(|| TermSize::get_current()) { let size = if let Some((x, y)) = pty_size.pixels { ::pty_process::Size::new_with_pixel(pty_size.rows, pty_size.cols, x, y) } else { diff --git a/core/startos/src/service/mod.rs b/core/startos/src/service/mod.rs index 95e53f83c..b76065296 100644 --- a/core/startos/src/service/mod.rs +++ b/core/startos/src/service/mod.rs @@ -8,7 +8,8 @@ use std::process::Stdio; use std::sync::{Arc, Weak}; use std::time::Duration; -use axum::extract::ws::{Utf8Bytes, WebSocket}; +use axum::extract::ws::Utf8Bytes; +use crate::util::net::WebSocket; use clap::Parser; use futures::future::BoxFuture; use futures::stream::FusedStream; @@ -47,7 +48,6 @@ use crate::util::Never; use crate::util::actor::concurrent::ConcurrentActor; use crate::util::future::NonDetachingJoinHandle; use crate::util::io::{AsyncReadStream, AtomicFile, TermSize, delete_file}; -use crate::util::net::WebSocketExt; use crate::util::serde::Pem; use crate::util::sync::SyncMutex; use crate::volume::data_dir; diff --git a/core/startos/src/system.rs b/core/startos/src/system.rs index 5b14611d8..f12f6195f 100644 --- a/core/startos/src/system.rs +++ b/core/startos/src/system.rs @@ -5,7 +5,7 @@ use std::time::Duration; use chrono::Utc; use clap::Parser; use color_eyre::eyre::eyre; -use futures::{FutureExt, TryStreamExt}; +use futures::FutureExt; use imbl::vector; use imbl_value::InternedString; use rpc_toolkit::{Context, Empty, HandlerExt, ParentHandler, from_fn_async}; @@ -24,7 +24,6 @@ use crate::shutdown::Shutdown; use crate::util::Invoke; use crate::util::cpupower::{Governor, get_available_governors, set_governor}; use crate::util::io::open_file; -use crate::util::net::WebSocketExt; use crate::util::serde::{HandlerExtSerde, WithIoFormat, display_serializable}; use crate::util::sync::Watch; use crate::{MAIN_DATA, PACKAGE_DATA}; @@ -527,8 +526,8 @@ pub async fn metrics_follow( .into(), )).await.with_kind(ErrorKind::Network)?; } - msg = ws.try_next() => { - if msg.with_kind(crate::ErrorKind::Network)?.is_none() { + msg = ws.recv() => { + if msg.transpose().with_kind(crate::ErrorKind::Network)?.is_none() { break; } } diff --git a/core/startos/src/tunnel/db.rs b/core/startos/src/tunnel/db.rs index 863943543..848e8481a 100644 --- a/core/startos/src/tunnel/db.rs +++ b/core/startos/src/tunnel/db.rs @@ -27,7 +27,6 @@ use crate::tunnel::auth::SignerInfo; use crate::tunnel::context::TunnelContext; use crate::tunnel::web::WebserverInfo; use crate::tunnel::wg::WgServer; -use crate::util::net::WebSocketExt; use crate::util::serde::{HandlerExtSerde, apply_expr}; #[derive(Default, Deserialize, Serialize, HasModel, TS)] diff --git a/core/startos/src/update/mod.rs b/core/startos/src/update/mod.rs index 44d67e3c9..3595cb42a 100644 --- a/core/startos/src/update/mod.rs +++ b/core/startos/src/update/mod.rs @@ -37,7 +37,6 @@ use crate::sound::{ use crate::util::Invoke; use crate::util::future::NonDetachingJoinHandle; use crate::util::io::AtomicFile; -use crate::util::net::WebSocketExt; #[derive(Deserialize, Serialize, Parser, TS)] #[serde(rename_all = "camelCase")] diff --git a/core/startos/src/util/net.rs b/core/startos/src/util/net.rs index a44edd1ad..53168d310 100644 --- a/core/startos/src/util/net.rs +++ b/core/startos/src/util/net.rs @@ -1,40 +1,154 @@ use core::fmt; +use std::pin::Pin; use std::sync::Mutex; +use std::task::{Context, Poll, ready}; +use std::time::Duration; -use axum::extract::ws::{self, CloseFrame, Utf8Bytes}; -use futures::{Future, Stream, StreamExt}; +use axum::extract::ws::{CloseFrame, Message, Utf8Bytes, WebSocket as AxumWebSocket}; +use futures::{Sink, SinkExt, Stream, StreamExt}; +use tokio::time::{Instant, Sleep}; use crate::prelude::*; -pub trait WebSocketExt { - fn normal_close( - self, - msg: impl Into + Send, - ) -> impl Future> + Send; - fn close_result( - self, - result: Result + Send, impl fmt::Display + Send>, - ) -> impl Future> + Send; +const PING_INTERVAL: Duration = Duration::from_secs(30); +const PING_TIMEOUT: Duration = Duration::from_secs(300); + +/// A wrapper around axum's WebSocket that automatically sends ping frames +/// to keep the connection alive during HTTP/2. +/// +/// HTTP/2 streams can timeout if idle, even when the underlying connection +/// has keep-alive enabled. This wrapper sends a ping frame if no data has +/// been sent within the ping interval while waiting to receive a message. +pub struct WebSocket { + inner: AxumWebSocket, + ping_state: Option<(bool, u64)>, + next_ping: Pin>, + fused: bool, } -impl WebSocketExt for ws::WebSocket { - async fn normal_close(self, msg: impl Into + Send) -> Result<(), Error> { - self.close_result(Ok::<_, Error>(msg)).await +impl WebSocket { + pub fn new(ws: AxumWebSocket) -> Self { + Self { + inner: ws, + ping_state: None, + next_ping: Box::pin(tokio::time::sleep(PING_INTERVAL)), + fused: false, + } } - async fn close_result( + + pub fn into_inner(self) -> AxumWebSocket { + self.inner + } + + pub async fn send(&mut self, msg: Message) -> Result<(), axum::Error> { + if self.ping_state.is_none() { + self.next_ping + .as_mut() + .reset(Instant::now() + PING_INTERVAL); + } + self.inner.send(msg).await + } + + pub fn poll_recv( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.fused { + return Poll::Ready(None); + } + let mut inner = Pin::new(&mut self.inner); + + loop { + if let Poll::Ready(msg) = inner.as_mut().poll_next(cx) { + if self.ping_state.is_none() { + self.next_ping + .as_mut() + .reset(Instant::now() + PING_INTERVAL); + } + + if let Some(Ok(Message::Pong(x))) = &msg { + if let Some((true, id)) = self.ping_state { + if &u64::to_be_bytes(id)[..] == &**x { + self.ping_state.take(); + self.next_ping + .as_mut() + .reset(Instant::now() + PING_INTERVAL); + continue; + } + } + } + + break Poll::Ready(msg); + } + + if let Some((sent, id)) = &mut self.ping_state { + if !*sent { + ready!(inner.as_mut().poll_ready(cx))?; + inner + .as_mut() + .start_send(Message::Ping((u64::to_be_bytes(*id).to_vec()).into()))?; + self.next_ping.as_mut().reset(Instant::now() + PING_TIMEOUT); + *sent = true; + } + ready!(inner.as_mut().poll_flush(cx))?; + } + + ready!(self.next_ping.as_mut().poll(cx)); + if self.ping_state.is_some() { + self.fused = true; + break Poll::Ready(Some(Err(axum::Error::new(eyre!( + "Timeout: WebSocket did not respond to ping within {PING_TIMEOUT:?}" + ))))); + } + self.ping_state = Some((false, rand::random())); + } + } + + /// Receive a message from the websocket, automatically sending ping frames + /// if the connection is idle for too long. + /// + /// Ping and Pong frames are handled internally and not returned to the caller. + pub async fn recv(&mut self) -> Option> { + futures::future::poll_fn(|cx| self.poll_recv(cx)).await + } + + /// Close the websocket connection normally. + pub async fn normal_close(mut self, msg: impl Into) -> Result<(), Error> { + self.inner + .send(Message::Close(Some(CloseFrame { + code: 1000, + reason: msg.into(), + }))) + .await + .with_kind(ErrorKind::Network)?; + while !matches!( + self.inner + .recv() + .await + .transpose() + .with_kind(ErrorKind::Network)?, + Some(Message::Close(_)) | None + ) {} + Ok(()) + } + + /// Close the websocket connection with a result. + pub async fn close_result( mut self, result: Result + Send, impl fmt::Display + Send>, ) -> Result<(), Error> { match result { Ok(msg) => self - .send(ws::Message::Close(Some(CloseFrame { + .inner + .send(Message::Close(Some(CloseFrame { code: 1000, reason: msg.into(), }))) .await .with_kind(ErrorKind::Network)?, Err(e) => self - .send(ws::Message::Close(Some(CloseFrame { + .inner + .send(Message::Close(Some(CloseFrame { code: 1011, reason: e.to_string().into(), }))) @@ -42,16 +156,51 @@ impl WebSocketExt for ws::WebSocket { .with_kind(ErrorKind::Network)?, } while !matches!( - self.recv() + self.inner + .recv() .await .transpose() .with_kind(ErrorKind::Network)?, - Some(ws::Message::Close(_)) | None + Some(Message::Close(_)) | None ) {} Ok(()) } } +impl From for WebSocket { + fn from(ws: AxumWebSocket) -> Self { + Self::new(ws) + } +} + +impl Stream for WebSocket { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().poll_recv(cx) + } +} + +impl Sink for WebSocket { + type Error = axum::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().inner.poll_ready_unpin(cx) + } + + fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { + self.get_mut().inner.start_send_unpin(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().inner.poll_flush_unpin(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().inner.poll_close_unpin(cx) + } +} + pub struct SyncBody(Mutex); impl From for SyncBody { fn from(value: axum::body::Body) -> Self {