diff --git a/appmgr/src/context/rpc.rs b/appmgr/src/context/rpc.rs index cf460c9f9..b119fc673 100644 --- a/appmgr/src/context/rpc.rs +++ b/appmgr/src/context/rpc.rs @@ -1,4 +1,4 @@ -use std::collections::VecDeque; +use std::collections::{BTreeMap, VecDeque}; use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}; use std::ops::Deref; use std::path::{Path, PathBuf}; @@ -16,13 +16,13 @@ use serde::Deserialize; use sqlx::sqlite::SqliteConnectOptions; use sqlx::SqlitePool; use tokio::fs::File; -use tokio::sync::broadcast::Sender; -use tokio::sync::RwLock; +use tokio::sync::{broadcast, oneshot, Mutex, RwLock}; use tracing::instrument; use crate::db::model::Database; use crate::hostname::{get_hostname, get_id}; use crate::manager::ManagerMap; +use crate::middleware::auth::HashSessionToken; use crate::net::tor::os_key; use crate::net::NetController; use crate::notifications::NotificationManager; @@ -115,12 +115,13 @@ pub struct RpcContextSeed { pub revision_cache_size: usize, pub revision_cache: RwLock>>, pub metrics_cache: RwLock>, - pub shutdown: Sender>, + pub shutdown: broadcast::Sender>, pub websocket_count: AtomicUsize, pub logger: EmbassyLogger, pub log_epoch: Arc, pub tor_socks: SocketAddr, pub notification_manager: NotificationManager, + pub open_authed_websockets: Mutex>>>, } #[derive(Clone)] @@ -185,6 +186,7 @@ impl RpcContext { 9050, ))), notification_manager, + open_authed_websockets: Mutex::new(BTreeMap::new()), }); let metrics_seed = seed.clone(); tokio::spawn(async move { diff --git a/appmgr/src/db/mod.rs b/appmgr/src/db/mod.rs index 7a8618c4e..de4d51d83 100644 --- a/appmgr/src/db/mod.rs +++ b/appmgr/src/db/mod.rs @@ -16,6 +16,7 @@ use rpc_toolkit::hyper::{Body, Error as HyperError, Request, Response}; use rpc_toolkit::yajrc::{GenericRpcMethod, RpcError, RpcResponse}; use serde::{Deserialize, Serialize}; use serde_json::Value; +use tokio::sync::{broadcast, oneshot}; use tokio::task::JoinError; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::WebSocketStream; @@ -54,7 +55,7 @@ async fn ws_handler< () }); - let has_valid_session = loop { + let (has_valid_session, token) = loop { if let Some(Message::Text(cookie)) = stream .next() .await @@ -90,23 +91,43 @@ async fn ws_handler< .with_kind(crate::ErrorKind::Network)?; return Ok(()); } - Ok(has_validation) => break has_validation, + Ok(has_validation) => break (has_validation, authenticated_session), } } }; + let kill = subscribe_to_session_kill(&ctx, token).await; send_dump(has_valid_session, &mut stream, dump).await?; - deal_with_messages(has_valid_session, sub, stream).await?; + deal_with_messages(has_valid_session, kill, sub, stream).await?; Ok(()) } +async fn subscribe_to_session_kill( + ctx: &RpcContext, + token: HashSessionToken, +) -> oneshot::Receiver<()> { + let (send, recv) = oneshot::channel(); + let mut guard = ctx.open_authed_websockets.lock().await; + if !guard.contains_key(&token) { + guard.insert(token, vec![send]); + } else { + guard.get_mut(&token).unwrap().push(send); + } + recv +} + async fn deal_with_messages( _has_valid_authentication: HasValidSession, - mut sub: tokio::sync::broadcast::Receiver>, + mut kill: oneshot::Receiver<()>, + mut sub: broadcast::Receiver>, mut stream: WebSocketStream, ) -> Result<(), Error> { loop { futures::select! { + _ = (&mut kill).fuse() => { + tracing::info!("Closing WebSocket: Reason: Session Terminated"); + return Ok(()) + } new_rev = sub.recv().fuse() => { let rev = new_rev.with_kind(crate::ErrorKind::Database)?; stream @@ -132,6 +153,8 @@ async fn deal_with_messages( Some(Message::Close(frame)) => { if let Some(reason) = frame.as_ref() { tracing::info!("Closing WebSocket: Reason: {} {}", reason.code, reason.reason); + } else { + tracing::info!("Closing WebSocket: Reason: Unknown"); } stream .send(Message::Close(frame)) diff --git a/appmgr/src/middleware/auth.rs b/appmgr/src/middleware/auth.rs index 720507045..e04ca2c1d 100644 --- a/appmgr/src/middleware/auth.rs +++ b/appmgr/src/middleware/auth.rs @@ -1,3 +1,5 @@ +use std::borrow::Borrow; + use basic_cookies::Cookie; use color_eyre::eyre::eyre; use digest::Digest; @@ -29,17 +31,28 @@ impl HasLoggedOutSessions { logged_out_sessions: impl IntoIterator, ctx: &RpcContext, ) -> Result { + let sessions = logged_out_sessions + .into_iter() + .by_ref() + .map(|x| x.as_logout_session_id()) + .collect::>(); sqlx::query(&format!( "UPDATE session SET logged_out = CURRENT_TIMESTAMP WHERE id IN ('{}')", - logged_out_sessions - .into_iter() - .by_ref() - .map(|x| x.as_logout_session_id()) - .collect::>() - .join("','") + sessions.join("','") )) .execute(&mut ctx.secret_store.acquire().await?) .await?; + for session in sessions { + for socket in ctx + .open_authed_websockets + .lock() + .await + .remove(&session) + .unwrap_or_default() + { + let _ = socket.send(()); + } + } Ok(Self(())) } } @@ -142,6 +155,27 @@ impl AsLogoutSessionId for HashSessionToken { self.hashed } } +impl PartialEq for HashSessionToken { + fn eq(&self, other: &Self) -> bool { + self.hashed == other.hashed + } +} +impl Eq for HashSessionToken {} +impl PartialOrd for HashSessionToken { + fn partial_cmp(&self, other: &Self) -> Option { + self.hashed.partial_cmp(&other.hashed) + } +} +impl Ord for HashSessionToken { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.hashed.cmp(&other.hashed) + } +} +impl Borrow for HashSessionToken { + fn borrow(&self) -> &String { + &self.hashed + } +} pub fn auth(ctx: RpcContext) -> DynMiddleware { Box::new(