This commit is contained in:
Aiden McClelland
2021-11-08 15:56:51 -07:00
committed by Aiden McClelland
parent d55586755d
commit f922a6f08c
3 changed files with 73 additions and 14 deletions

View File

@@ -1,4 +1,4 @@
use std::collections::VecDeque;
use std::collections::{BTreeMap, VecDeque};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use std::ops::Deref;
use std::path::{Path, PathBuf};
@@ -16,13 +16,13 @@ use serde::Deserialize;
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::SqlitePool;
use tokio::fs::File;
use tokio::sync::broadcast::Sender;
use tokio::sync::RwLock;
use tokio::sync::{broadcast, oneshot, Mutex, RwLock};
use tracing::instrument;
use crate::db::model::Database;
use crate::hostname::{get_hostname, get_id};
use crate::manager::ManagerMap;
use crate::middleware::auth::HashSessionToken;
use crate::net::tor::os_key;
use crate::net::NetController;
use crate::notifications::NotificationManager;
@@ -115,12 +115,13 @@ pub struct RpcContextSeed {
pub revision_cache_size: usize,
pub revision_cache: RwLock<VecDeque<Arc<Revision>>>,
pub metrics_cache: RwLock<Option<crate::system::Metrics>>,
pub shutdown: Sender<Option<Shutdown>>,
pub shutdown: broadcast::Sender<Option<Shutdown>>,
pub websocket_count: AtomicUsize,
pub logger: EmbassyLogger,
pub log_epoch: Arc<AtomicU64>,
pub tor_socks: SocketAddr,
pub notification_manager: NotificationManager,
pub open_authed_websockets: Mutex<BTreeMap<HashSessionToken, Vec<oneshot::Sender<()>>>>,
}
#[derive(Clone)]
@@ -185,6 +186,7 @@ impl RpcContext {
9050,
))),
notification_manager,
open_authed_websockets: Mutex::new(BTreeMap::new()),
});
let metrics_seed = seed.clone();
tokio::spawn(async move {

View File

@@ -16,6 +16,7 @@ use rpc_toolkit::hyper::{Body, Error as HyperError, Request, Response};
use rpc_toolkit::yajrc::{GenericRpcMethod, RpcError, RpcResponse};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::{broadcast, oneshot};
use tokio::task::JoinError;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::WebSocketStream;
@@ -54,7 +55,7 @@ async fn ws_handler<
()
});
let has_valid_session = loop {
let (has_valid_session, token) = loop {
if let Some(Message::Text(cookie)) = stream
.next()
.await
@@ -90,23 +91,43 @@ async fn ws_handler<
.with_kind(crate::ErrorKind::Network)?;
return Ok(());
}
Ok(has_validation) => break has_validation,
Ok(has_validation) => break (has_validation, authenticated_session),
}
}
};
let kill = subscribe_to_session_kill(&ctx, token).await;
send_dump(has_valid_session, &mut stream, dump).await?;
deal_with_messages(has_valid_session, sub, stream).await?;
deal_with_messages(has_valid_session, kill, sub, stream).await?;
Ok(())
}
async fn subscribe_to_session_kill(
ctx: &RpcContext,
token: HashSessionToken,
) -> oneshot::Receiver<()> {
let (send, recv) = oneshot::channel();
let mut guard = ctx.open_authed_websockets.lock().await;
if !guard.contains_key(&token) {
guard.insert(token, vec![send]);
} else {
guard.get_mut(&token).unwrap().push(send);
}
recv
}
async fn deal_with_messages(
_has_valid_authentication: HasValidSession,
mut sub: tokio::sync::broadcast::Receiver<Arc<Revision>>,
mut kill: oneshot::Receiver<()>,
mut sub: broadcast::Receiver<Arc<Revision>>,
mut stream: WebSocketStream<Upgraded>,
) -> Result<(), Error> {
loop {
futures::select! {
_ = (&mut kill).fuse() => {
tracing::info!("Closing WebSocket: Reason: Session Terminated");
return Ok(())
}
new_rev = sub.recv().fuse() => {
let rev = new_rev.with_kind(crate::ErrorKind::Database)?;
stream
@@ -132,6 +153,8 @@ async fn deal_with_messages(
Some(Message::Close(frame)) => {
if let Some(reason) = frame.as_ref() {
tracing::info!("Closing WebSocket: Reason: {} {}", reason.code, reason.reason);
} else {
tracing::info!("Closing WebSocket: Reason: Unknown");
}
stream
.send(Message::Close(frame))

View File

@@ -1,3 +1,5 @@
use std::borrow::Borrow;
use basic_cookies::Cookie;
use color_eyre::eyre::eyre;
use digest::Digest;
@@ -29,17 +31,28 @@ impl HasLoggedOutSessions {
logged_out_sessions: impl IntoIterator<Item = impl AsLogoutSessionId>,
ctx: &RpcContext,
) -> Result<Self, Error> {
let sessions = logged_out_sessions
.into_iter()
.by_ref()
.map(|x| x.as_logout_session_id())
.collect::<Vec<_>>();
sqlx::query(&format!(
"UPDATE session SET logged_out = CURRENT_TIMESTAMP WHERE id IN ('{}')",
logged_out_sessions
.into_iter()
.by_ref()
.map(|x| x.as_logout_session_id())
.collect::<Vec<_>>()
.join("','")
sessions.join("','")
))
.execute(&mut ctx.secret_store.acquire().await?)
.await?;
for session in sessions {
for socket in ctx
.open_authed_websockets
.lock()
.await
.remove(&session)
.unwrap_or_default()
{
let _ = socket.send(());
}
}
Ok(Self(()))
}
}
@@ -142,6 +155,27 @@ impl AsLogoutSessionId for HashSessionToken {
self.hashed
}
}
impl PartialEq for HashSessionToken {
fn eq(&self, other: &Self) -> bool {
self.hashed == other.hashed
}
}
impl Eq for HashSessionToken {}
impl PartialOrd for HashSessionToken {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.hashed.partial_cmp(&other.hashed)
}
}
impl Ord for HashSessionToken {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.hashed.cmp(&other.hashed)
}
}
impl Borrow<String> for HashSessionToken {
fn borrow(&self) -> &String {
&self.hashed
}
}
pub fn auth<M: Metadata>(ctx: RpcContext) -> DynMiddleware<M> {
Box::new(