mirror of
https://github.com/Start9Labs/start-os.git
synced 2026-03-26 10:21:52 +00:00
closes #659
This commit is contained in:
committed by
Aiden McClelland
parent
d55586755d
commit
f922a6f08c
@@ -1,4 +1,4 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::collections::{BTreeMap, VecDeque};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
|
||||
use std::ops::Deref;
|
||||
use std::path::{Path, PathBuf};
|
||||
@@ -16,13 +16,13 @@ use serde::Deserialize;
|
||||
use sqlx::sqlite::SqliteConnectOptions;
|
||||
use sqlx::SqlitePool;
|
||||
use tokio::fs::File;
|
||||
use tokio::sync::broadcast::Sender;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::{broadcast, oneshot, Mutex, RwLock};
|
||||
use tracing::instrument;
|
||||
|
||||
use crate::db::model::Database;
|
||||
use crate::hostname::{get_hostname, get_id};
|
||||
use crate::manager::ManagerMap;
|
||||
use crate::middleware::auth::HashSessionToken;
|
||||
use crate::net::tor::os_key;
|
||||
use crate::net::NetController;
|
||||
use crate::notifications::NotificationManager;
|
||||
@@ -115,12 +115,13 @@ pub struct RpcContextSeed {
|
||||
pub revision_cache_size: usize,
|
||||
pub revision_cache: RwLock<VecDeque<Arc<Revision>>>,
|
||||
pub metrics_cache: RwLock<Option<crate::system::Metrics>>,
|
||||
pub shutdown: Sender<Option<Shutdown>>,
|
||||
pub shutdown: broadcast::Sender<Option<Shutdown>>,
|
||||
pub websocket_count: AtomicUsize,
|
||||
pub logger: EmbassyLogger,
|
||||
pub log_epoch: Arc<AtomicU64>,
|
||||
pub tor_socks: SocketAddr,
|
||||
pub notification_manager: NotificationManager,
|
||||
pub open_authed_websockets: Mutex<BTreeMap<HashSessionToken, Vec<oneshot::Sender<()>>>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -185,6 +186,7 @@ impl RpcContext {
|
||||
9050,
|
||||
))),
|
||||
notification_manager,
|
||||
open_authed_websockets: Mutex::new(BTreeMap::new()),
|
||||
});
|
||||
let metrics_seed = seed.clone();
|
||||
tokio::spawn(async move {
|
||||
|
||||
@@ -16,6 +16,7 @@ use rpc_toolkit::hyper::{Body, Error as HyperError, Request, Response};
|
||||
use rpc_toolkit::yajrc::{GenericRpcMethod, RpcError, RpcResponse};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::{broadcast, oneshot};
|
||||
use tokio::task::JoinError;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
@@ -54,7 +55,7 @@ async fn ws_handler<
|
||||
()
|
||||
});
|
||||
|
||||
let has_valid_session = loop {
|
||||
let (has_valid_session, token) = loop {
|
||||
if let Some(Message::Text(cookie)) = stream
|
||||
.next()
|
||||
.await
|
||||
@@ -90,23 +91,43 @@ async fn ws_handler<
|
||||
.with_kind(crate::ErrorKind::Network)?;
|
||||
return Ok(());
|
||||
}
|
||||
Ok(has_validation) => break has_validation,
|
||||
Ok(has_validation) => break (has_validation, authenticated_session),
|
||||
}
|
||||
}
|
||||
};
|
||||
let kill = subscribe_to_session_kill(&ctx, token).await;
|
||||
send_dump(has_valid_session, &mut stream, dump).await?;
|
||||
|
||||
deal_with_messages(has_valid_session, sub, stream).await?;
|
||||
deal_with_messages(has_valid_session, kill, sub, stream).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn subscribe_to_session_kill(
|
||||
ctx: &RpcContext,
|
||||
token: HashSessionToken,
|
||||
) -> oneshot::Receiver<()> {
|
||||
let (send, recv) = oneshot::channel();
|
||||
let mut guard = ctx.open_authed_websockets.lock().await;
|
||||
if !guard.contains_key(&token) {
|
||||
guard.insert(token, vec![send]);
|
||||
} else {
|
||||
guard.get_mut(&token).unwrap().push(send);
|
||||
}
|
||||
recv
|
||||
}
|
||||
|
||||
async fn deal_with_messages(
|
||||
_has_valid_authentication: HasValidSession,
|
||||
mut sub: tokio::sync::broadcast::Receiver<Arc<Revision>>,
|
||||
mut kill: oneshot::Receiver<()>,
|
||||
mut sub: broadcast::Receiver<Arc<Revision>>,
|
||||
mut stream: WebSocketStream<Upgraded>,
|
||||
) -> Result<(), Error> {
|
||||
loop {
|
||||
futures::select! {
|
||||
_ = (&mut kill).fuse() => {
|
||||
tracing::info!("Closing WebSocket: Reason: Session Terminated");
|
||||
return Ok(())
|
||||
}
|
||||
new_rev = sub.recv().fuse() => {
|
||||
let rev = new_rev.with_kind(crate::ErrorKind::Database)?;
|
||||
stream
|
||||
@@ -132,6 +153,8 @@ async fn deal_with_messages(
|
||||
Some(Message::Close(frame)) => {
|
||||
if let Some(reason) = frame.as_ref() {
|
||||
tracing::info!("Closing WebSocket: Reason: {} {}", reason.code, reason.reason);
|
||||
} else {
|
||||
tracing::info!("Closing WebSocket: Reason: Unknown");
|
||||
}
|
||||
stream
|
||||
.send(Message::Close(frame))
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::borrow::Borrow;
|
||||
|
||||
use basic_cookies::Cookie;
|
||||
use color_eyre::eyre::eyre;
|
||||
use digest::Digest;
|
||||
@@ -29,17 +31,28 @@ impl HasLoggedOutSessions {
|
||||
logged_out_sessions: impl IntoIterator<Item = impl AsLogoutSessionId>,
|
||||
ctx: &RpcContext,
|
||||
) -> Result<Self, Error> {
|
||||
let sessions = logged_out_sessions
|
||||
.into_iter()
|
||||
.by_ref()
|
||||
.map(|x| x.as_logout_session_id())
|
||||
.collect::<Vec<_>>();
|
||||
sqlx::query(&format!(
|
||||
"UPDATE session SET logged_out = CURRENT_TIMESTAMP WHERE id IN ('{}')",
|
||||
logged_out_sessions
|
||||
.into_iter()
|
||||
.by_ref()
|
||||
.map(|x| x.as_logout_session_id())
|
||||
.collect::<Vec<_>>()
|
||||
.join("','")
|
||||
sessions.join("','")
|
||||
))
|
||||
.execute(&mut ctx.secret_store.acquire().await?)
|
||||
.await?;
|
||||
for session in sessions {
|
||||
for socket in ctx
|
||||
.open_authed_websockets
|
||||
.lock()
|
||||
.await
|
||||
.remove(&session)
|
||||
.unwrap_or_default()
|
||||
{
|
||||
let _ = socket.send(());
|
||||
}
|
||||
}
|
||||
Ok(Self(()))
|
||||
}
|
||||
}
|
||||
@@ -142,6 +155,27 @@ impl AsLogoutSessionId for HashSessionToken {
|
||||
self.hashed
|
||||
}
|
||||
}
|
||||
impl PartialEq for HashSessionToken {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.hashed == other.hashed
|
||||
}
|
||||
}
|
||||
impl Eq for HashSessionToken {}
|
||||
impl PartialOrd for HashSessionToken {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
self.hashed.partial_cmp(&other.hashed)
|
||||
}
|
||||
}
|
||||
impl Ord for HashSessionToken {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.hashed.cmp(&other.hashed)
|
||||
}
|
||||
}
|
||||
impl Borrow<String> for HashSessionToken {
|
||||
fn borrow(&self) -> &String {
|
||||
&self.hashed
|
||||
}
|
||||
}
|
||||
|
||||
pub fn auth<M: Metadata>(ctx: RpcContext) -> DynMiddleware<M> {
|
||||
Box::new(
|
||||
|
||||
Reference in New Issue
Block a user