mirror of
https://github.com/Start9Labs/start-os.git
synced 2026-03-31 04:23:40 +00:00
closes #659
This commit is contained in:
committed by
Aiden McClelland
parent
d55586755d
commit
f922a6f08c
@@ -1,4 +1,4 @@
|
|||||||
use std::collections::VecDeque;
|
use std::collections::{BTreeMap, VecDeque};
|
||||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
|
||||||
use std::ops::Deref;
|
use std::ops::Deref;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
@@ -16,13 +16,13 @@ use serde::Deserialize;
|
|||||||
use sqlx::sqlite::SqliteConnectOptions;
|
use sqlx::sqlite::SqliteConnectOptions;
|
||||||
use sqlx::SqlitePool;
|
use sqlx::SqlitePool;
|
||||||
use tokio::fs::File;
|
use tokio::fs::File;
|
||||||
use tokio::sync::broadcast::Sender;
|
use tokio::sync::{broadcast, oneshot, Mutex, RwLock};
|
||||||
use tokio::sync::RwLock;
|
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
|
||||||
use crate::db::model::Database;
|
use crate::db::model::Database;
|
||||||
use crate::hostname::{get_hostname, get_id};
|
use crate::hostname::{get_hostname, get_id};
|
||||||
use crate::manager::ManagerMap;
|
use crate::manager::ManagerMap;
|
||||||
|
use crate::middleware::auth::HashSessionToken;
|
||||||
use crate::net::tor::os_key;
|
use crate::net::tor::os_key;
|
||||||
use crate::net::NetController;
|
use crate::net::NetController;
|
||||||
use crate::notifications::NotificationManager;
|
use crate::notifications::NotificationManager;
|
||||||
@@ -115,12 +115,13 @@ pub struct RpcContextSeed {
|
|||||||
pub revision_cache_size: usize,
|
pub revision_cache_size: usize,
|
||||||
pub revision_cache: RwLock<VecDeque<Arc<Revision>>>,
|
pub revision_cache: RwLock<VecDeque<Arc<Revision>>>,
|
||||||
pub metrics_cache: RwLock<Option<crate::system::Metrics>>,
|
pub metrics_cache: RwLock<Option<crate::system::Metrics>>,
|
||||||
pub shutdown: Sender<Option<Shutdown>>,
|
pub shutdown: broadcast::Sender<Option<Shutdown>>,
|
||||||
pub websocket_count: AtomicUsize,
|
pub websocket_count: AtomicUsize,
|
||||||
pub logger: EmbassyLogger,
|
pub logger: EmbassyLogger,
|
||||||
pub log_epoch: Arc<AtomicU64>,
|
pub log_epoch: Arc<AtomicU64>,
|
||||||
pub tor_socks: SocketAddr,
|
pub tor_socks: SocketAddr,
|
||||||
pub notification_manager: NotificationManager,
|
pub notification_manager: NotificationManager,
|
||||||
|
pub open_authed_websockets: Mutex<BTreeMap<HashSessionToken, Vec<oneshot::Sender<()>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@@ -185,6 +186,7 @@ impl RpcContext {
|
|||||||
9050,
|
9050,
|
||||||
))),
|
))),
|
||||||
notification_manager,
|
notification_manager,
|
||||||
|
open_authed_websockets: Mutex::new(BTreeMap::new()),
|
||||||
});
|
});
|
||||||
let metrics_seed = seed.clone();
|
let metrics_seed = seed.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ use rpc_toolkit::hyper::{Body, Error as HyperError, Request, Response};
|
|||||||
use rpc_toolkit::yajrc::{GenericRpcMethod, RpcError, RpcResponse};
|
use rpc_toolkit::yajrc::{GenericRpcMethod, RpcError, RpcResponse};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
use tokio::sync::{broadcast, oneshot};
|
||||||
use tokio::task::JoinError;
|
use tokio::task::JoinError;
|
||||||
use tokio_tungstenite::tungstenite::Message;
|
use tokio_tungstenite::tungstenite::Message;
|
||||||
use tokio_tungstenite::WebSocketStream;
|
use tokio_tungstenite::WebSocketStream;
|
||||||
@@ -54,7 +55,7 @@ async fn ws_handler<
|
|||||||
()
|
()
|
||||||
});
|
});
|
||||||
|
|
||||||
let has_valid_session = loop {
|
let (has_valid_session, token) = loop {
|
||||||
if let Some(Message::Text(cookie)) = stream
|
if let Some(Message::Text(cookie)) = stream
|
||||||
.next()
|
.next()
|
||||||
.await
|
.await
|
||||||
@@ -90,23 +91,43 @@ async fn ws_handler<
|
|||||||
.with_kind(crate::ErrorKind::Network)?;
|
.with_kind(crate::ErrorKind::Network)?;
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
Ok(has_validation) => break has_validation,
|
Ok(has_validation) => break (has_validation, authenticated_session),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
let kill = subscribe_to_session_kill(&ctx, token).await;
|
||||||
send_dump(has_valid_session, &mut stream, dump).await?;
|
send_dump(has_valid_session, &mut stream, dump).await?;
|
||||||
|
|
||||||
deal_with_messages(has_valid_session, sub, stream).await?;
|
deal_with_messages(has_valid_session, kill, sub, stream).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn subscribe_to_session_kill(
|
||||||
|
ctx: &RpcContext,
|
||||||
|
token: HashSessionToken,
|
||||||
|
) -> oneshot::Receiver<()> {
|
||||||
|
let (send, recv) = oneshot::channel();
|
||||||
|
let mut guard = ctx.open_authed_websockets.lock().await;
|
||||||
|
if !guard.contains_key(&token) {
|
||||||
|
guard.insert(token, vec![send]);
|
||||||
|
} else {
|
||||||
|
guard.get_mut(&token).unwrap().push(send);
|
||||||
|
}
|
||||||
|
recv
|
||||||
|
}
|
||||||
|
|
||||||
async fn deal_with_messages(
|
async fn deal_with_messages(
|
||||||
_has_valid_authentication: HasValidSession,
|
_has_valid_authentication: HasValidSession,
|
||||||
mut sub: tokio::sync::broadcast::Receiver<Arc<Revision>>,
|
mut kill: oneshot::Receiver<()>,
|
||||||
|
mut sub: broadcast::Receiver<Arc<Revision>>,
|
||||||
mut stream: WebSocketStream<Upgraded>,
|
mut stream: WebSocketStream<Upgraded>,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
loop {
|
loop {
|
||||||
futures::select! {
|
futures::select! {
|
||||||
|
_ = (&mut kill).fuse() => {
|
||||||
|
tracing::info!("Closing WebSocket: Reason: Session Terminated");
|
||||||
|
return Ok(())
|
||||||
|
}
|
||||||
new_rev = sub.recv().fuse() => {
|
new_rev = sub.recv().fuse() => {
|
||||||
let rev = new_rev.with_kind(crate::ErrorKind::Database)?;
|
let rev = new_rev.with_kind(crate::ErrorKind::Database)?;
|
||||||
stream
|
stream
|
||||||
@@ -132,6 +153,8 @@ async fn deal_with_messages(
|
|||||||
Some(Message::Close(frame)) => {
|
Some(Message::Close(frame)) => {
|
||||||
if let Some(reason) = frame.as_ref() {
|
if let Some(reason) = frame.as_ref() {
|
||||||
tracing::info!("Closing WebSocket: Reason: {} {}", reason.code, reason.reason);
|
tracing::info!("Closing WebSocket: Reason: {} {}", reason.code, reason.reason);
|
||||||
|
} else {
|
||||||
|
tracing::info!("Closing WebSocket: Reason: Unknown");
|
||||||
}
|
}
|
||||||
stream
|
stream
|
||||||
.send(Message::Close(frame))
|
.send(Message::Close(frame))
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
use std::borrow::Borrow;
|
||||||
|
|
||||||
use basic_cookies::Cookie;
|
use basic_cookies::Cookie;
|
||||||
use color_eyre::eyre::eyre;
|
use color_eyre::eyre::eyre;
|
||||||
use digest::Digest;
|
use digest::Digest;
|
||||||
@@ -29,17 +31,28 @@ impl HasLoggedOutSessions {
|
|||||||
logged_out_sessions: impl IntoIterator<Item = impl AsLogoutSessionId>,
|
logged_out_sessions: impl IntoIterator<Item = impl AsLogoutSessionId>,
|
||||||
ctx: &RpcContext,
|
ctx: &RpcContext,
|
||||||
) -> Result<Self, Error> {
|
) -> Result<Self, Error> {
|
||||||
|
let sessions = logged_out_sessions
|
||||||
|
.into_iter()
|
||||||
|
.by_ref()
|
||||||
|
.map(|x| x.as_logout_session_id())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
sqlx::query(&format!(
|
sqlx::query(&format!(
|
||||||
"UPDATE session SET logged_out = CURRENT_TIMESTAMP WHERE id IN ('{}')",
|
"UPDATE session SET logged_out = CURRENT_TIMESTAMP WHERE id IN ('{}')",
|
||||||
logged_out_sessions
|
sessions.join("','")
|
||||||
.into_iter()
|
|
||||||
.by_ref()
|
|
||||||
.map(|x| x.as_logout_session_id())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join("','")
|
|
||||||
))
|
))
|
||||||
.execute(&mut ctx.secret_store.acquire().await?)
|
.execute(&mut ctx.secret_store.acquire().await?)
|
||||||
.await?;
|
.await?;
|
||||||
|
for session in sessions {
|
||||||
|
for socket in ctx
|
||||||
|
.open_authed_websockets
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.remove(&session)
|
||||||
|
.unwrap_or_default()
|
||||||
|
{
|
||||||
|
let _ = socket.send(());
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(Self(()))
|
Ok(Self(()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -142,6 +155,27 @@ impl AsLogoutSessionId for HashSessionToken {
|
|||||||
self.hashed
|
self.hashed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
impl PartialEq for HashSessionToken {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.hashed == other.hashed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl Eq for HashSessionToken {}
|
||||||
|
impl PartialOrd for HashSessionToken {
|
||||||
|
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||||
|
self.hashed.partial_cmp(&other.hashed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl Ord for HashSessionToken {
|
||||||
|
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||||
|
self.hashed.cmp(&other.hashed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl Borrow<String> for HashSessionToken {
|
||||||
|
fn borrow(&self) -> &String {
|
||||||
|
&self.hashed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn auth<M: Metadata>(ctx: RpcContext) -> DynMiddleware<M> {
|
pub fn auth<M: Metadata>(ctx: RpcContext) -> DynMiddleware<M> {
|
||||||
Box::new(
|
Box::new(
|
||||||
|
|||||||
Reference in New Issue
Block a user