enforce correct api use

This commit is contained in:
Keagan McClelland
2021-12-20 17:07:13 -07:00
committed by Aiden McClelland
parent d293e9a1f9
commit 0860275869

View File

@@ -1,6 +1,6 @@
use std::collections::{BTreeMap, VecDeque}; use std::collections::{BTreeMap, VecDeque};
use imbl::{ordset, OrdSet}; use imbl::{ordmap, ordset, OrdMap, OrdSet};
use json_ptr::{JsonPointer, SegList}; use json_ptr::{JsonPointer, SegList};
#[cfg(test)] #[cfg(test)]
use proptest::prelude::*; use proptest::prelude::*;
@@ -27,6 +27,8 @@ impl Locker {
let mut cancellations = vec![dummy_recv]; let mut cancellations = vec![dummy_recv];
let mut request_queue = VecDeque::<(Request, OrdSet<HandleId>)>::new(); let mut request_queue = VecDeque::<(Request, OrdSet<HandleId>)>::new();
let mut lock_order_enforcer = LockOrderEnforcer::new();
while let Some(action) = while let Some(action) =
get_action(&mut new_requests, &mut locks_on_lease, &mut cancellations).await get_action(&mut new_requests, &mut locks_on_lease, &mut cancellations).await
{ {
@@ -110,21 +112,26 @@ impl Locker {
Action::HandleRequest(mut req) => { Action::HandleRequest(mut req) => {
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
tracing::debug!("New lock request"); tracing::debug!("New lock request");
cancellations.extend(req.cancel.take()); if let Err(e) = lock_order_enforcer.try_insert(&req.lock_info) {
let hot_seat = request_queue.pop_front(); req.reject(e)
process_new_req( } else {
hot_seat.as_ref(), cancellations.extend(req.cancel.take());
req, let hot_seat = request_queue.pop_front();
&mut trie, process_new_req(
&mut locks_on_lease, hot_seat.as_ref(),
&mut request_queue, req,
); &mut trie,
if let Some(hot_seat) = hot_seat { &mut locks_on_lease,
request_queue.push_front(hot_seat); &mut request_queue,
);
if let Some(hot_seat) = hot_seat {
request_queue.push_front(hot_seat);
}
} }
} }
Action::HandleRelease(lock_info) => { Action::HandleRelease(lock_info) => {
// release actual lock // release actual lock
lock_order_enforcer.remove(&lock_info);
trie.unlock(&lock_info); trie.unlock(&lock_info);
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
{ {
@@ -176,7 +183,7 @@ impl Locker {
} }
} }
Action::HandleCancel(lock_info) => { Action::HandleCancel(lock_info) => {
// trie.handle_cancel(lock_info, &mut locks_on_lease) lock_order_enforcer.remove(&lock_info);
let entry = request_queue let entry = request_queue
.iter() .iter()
.enumerate() .enumerate()
@@ -218,7 +225,7 @@ impl Locker {
struct CancelGuard { struct CancelGuard {
lock_info: Option<LockInfo>, lock_info: Option<LockInfo>,
channel: Option<oneshot::Sender<LockInfo>>, channel: Option<oneshot::Sender<LockInfo>>,
recv: oneshot::Receiver<Guard>, recv: oneshot::Receiver<Result<Guard, LockError>>,
} }
impl Drop for CancelGuard { impl Drop for CancelGuard {
fn drop(&mut self) { fn drop(&mut self) {
@@ -251,7 +258,7 @@ impl Locker {
.unwrap(); .unwrap();
let res = (&mut cancel_guard.recv).await.unwrap(); let res = (&mut cancel_guard.recv).await.unwrap();
cancel_guard.channel.take(); cancel_guard.channel.take();
Ok(res) res
} }
} }
@@ -297,6 +304,115 @@ async fn get_action(
} }
} }
struct LockOrderEnforcer {
locks_held: OrdMap<HandleId, OrdMap<(JsonPointer, LockType), usize>>,
}
impl LockOrderEnforcer {
fn new() -> Self {
LockOrderEnforcer {
locks_held: ordmap! {},
}
}
// locks must be acquired in lexicographic order for the pointer, and reverse order for type
fn validate(&self, req: &LockInfo) -> Result<(), LockError> {
match self.locks_held.get(&req.handle_id) {
None => Ok(()),
Some(m) => {
// quick accept
for (ptr, ty) in m.keys() {
let tmp = LockInfo {
ptr: ptr.clone(),
ty: *ty,
handle_id: req.handle_id.clone(),
};
if tmp.implicitly_grants(req) {
return Ok(());
}
}
let err = m.keys().find_map(|(ptr, ty)| {
match ptr.cmp(&req.ptr) {
std::cmp::Ordering::Less => None, // this is OK
std::cmp::Ordering::Equal => {
if req.ty > *ty {
Some(LockError::LockTypeEscalation {
session: req.handle_id.clone(),
ptr: ptr.clone(),
first: *ty,
second: req.ty,
})
} else {
None
}
}
std::cmp::Ordering::Greater => Some(if ptr.starts_with(&req.ptr) {
LockError::LockTaxonomyEscalation {
session: req.handle_id.clone(),
first: ptr.clone(),
second: req.ptr.clone(),
}
} else {
LockError::NonCanonicalOrdering {
session: req.handle_id.clone(),
first: ptr.clone(),
second: req.ptr.clone(),
}
}),
}
});
err.map_or(Ok(()), Err)
}
}
}
fn try_insert(&mut self, req: &LockInfo) -> Result<(), LockError> {
self.validate(req)?;
match self.locks_held.get_mut(&req.handle_id) {
None => {
self.locks_held.insert(
req.handle_id.clone(),
ordmap![(req.ptr.clone(), req.ty) => 1],
);
}
Some(locks) => {
let k = (req.ptr.clone(), req.ty);
match locks.get_mut(&k) {
None => {
locks.insert(k, 1);
}
Some(n) => {
*n += 1;
}
}
}
}
Ok(())
}
fn remove(&mut self, req: &LockInfo) {
match self.locks_held.remove_with_key(&req.handle_id) {
None => {
#[cfg(feature = "tracing")]
tracing::warn!("Invalid removal from session manager: {}", req);
}
Some((hdl, mut locks)) => {
let k = (req.ptr.clone(), req.ty);
match locks.remove_with_key(&k) {
None => {
#[cfg(feature = "tracing")]
tracing::warn!("Invalid removal from session manager: {}", req);
}
Some((k, n)) => {
if n - 1 > 0 {
locks.insert(k, n - 1);
}
}
}
if !locks.is_empty() {
self.locks_held.insert(hdl, locks);
}
}
}
}
}
#[derive(Debug, Default)] #[derive(Debug, Default)]
struct Trie { struct Trie {
state: LockState, state: LockState,
@@ -891,6 +1007,24 @@ impl LockInfo {
} }
} }
} }
fn implicitly_grants(&self, other: &LockInfo) -> bool {
self.handle_id == other.handle_id
&& match self.ty {
LockType::Exist => other.ty == LockType::Exist && self.ptr.starts_with(&other.ptr),
LockType::Read => {
// E's in the ancestry
other.ty == LockType::Exist && self.ptr.starts_with(&other.ptr)
// nonexclusive locks in the subtree
|| other.ty != LockType::Write && other.ptr.starts_with(&self.ptr)
}
LockType::Write => {
// E's in the ancestry
other.ty == LockType::Exist && self.ptr.starts_with(&other.ptr)
// anything in the subtree
|| other.ptr.starts_with(&self.ptr)
}
}
}
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
@@ -944,20 +1078,26 @@ pub enum LockError {
struct Request { struct Request {
lock_info: LockInfo, lock_info: LockInfo,
cancel: Option<oneshot::Receiver<LockInfo>>, cancel: Option<oneshot::Receiver<LockInfo>>,
completion: oneshot::Sender<Guard>, completion: oneshot::Sender<Result<Guard, LockError>>,
} }
impl Request { impl Request {
fn complete(self) -> oneshot::Receiver<LockInfo> { fn complete(self) -> oneshot::Receiver<LockInfo> {
let (sender, receiver) = oneshot::channel(); let (sender, receiver) = oneshot::channel();
if let Err(_) = self.completion.send(Guard { if let Err(_) = self.completion.send(Ok(Guard {
lock_info: self.lock_info, lock_info: self.lock_info,
sender: Some(sender), sender: Some(sender),
}) { })) {
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
tracing::warn!("Completion sent to closed channel.") tracing::warn!("Completion sent to closed channel.")
} }
receiver receiver
} }
fn reject(self, err: LockError) {
if let Err(_) = self.completion.send(Err(err)) {
#[cfg(feature = "tracing")]
tracing::warn!("Rejection sent to closed channel.")
}
}
} }
#[derive(Debug)] #[derive(Debug)]