diff --git a/json-ptr b/json-ptr index f1d671b..5e973c8 160000 --- a/json-ptr +++ b/json-ptr @@ -1 +1 @@ -Subproject commit f1d671bd5194a99fcfee918b7c568355c2c3459a +Subproject commit 5e973c8ee7ecd37cef944865fdd95cc1e93d4788 diff --git a/patch-db/src/locker.rs b/patch-db/src/locker.rs index 7b61d3d..842296e 100644 --- a/patch-db/src/locker.rs +++ b/patch-db/src/locker.rs @@ -1,6 +1,6 @@ use std::collections::{BTreeMap, VecDeque}; -use imbl::{ordset, OrdSet}; +use imbl::{ordmap, ordset, OrdMap, OrdSet}; use json_ptr::{JsonPointer, SegList}; #[cfg(test)] use proptest::prelude::*; @@ -27,6 +27,8 @@ impl Locker { let mut cancellations = vec![dummy_recv]; let mut request_queue = VecDeque::<(Request, OrdSet)>::new(); + let mut lock_order_enforcer = LockOrderEnforcer::new(); + while let Some(action) = get_action(&mut new_requests, &mut locks_on_lease, &mut cancellations).await { @@ -110,21 +112,26 @@ impl Locker { Action::HandleRequest(mut req) => { #[cfg(feature = "tracing")] tracing::debug!("New lock request"); - cancellations.extend(req.cancel.take()); - let hot_seat = request_queue.pop_front(); - process_new_req( - hot_seat.as_ref(), - req, - &mut trie, - &mut locks_on_lease, - &mut request_queue, - ); - if let Some(hot_seat) = hot_seat { - request_queue.push_front(hot_seat); + if let Err(e) = lock_order_enforcer.try_insert(&req.lock_info) { + req.reject(e) + } else { + cancellations.extend(req.cancel.take()); + let hot_seat = request_queue.pop_front(); + process_new_req( + hot_seat.as_ref(), + req, + &mut trie, + &mut locks_on_lease, + &mut request_queue, + ); + if let Some(hot_seat) = hot_seat { + request_queue.push_front(hot_seat); + } } } Action::HandleRelease(lock_info) => { // release actual lock + lock_order_enforcer.remove(&lock_info); trie.unlock(&lock_info); #[cfg(feature = "tracing")] { @@ -176,7 +183,7 @@ impl Locker { } } Action::HandleCancel(lock_info) => { - // trie.handle_cancel(lock_info, &mut locks_on_lease) + lock_order_enforcer.remove(&lock_info); let entry = request_queue .iter() .enumerate() @@ -218,7 +225,7 @@ impl Locker { struct CancelGuard { lock_info: Option, channel: Option>, - recv: oneshot::Receiver, + recv: oneshot::Receiver>, } impl Drop for CancelGuard { fn drop(&mut self) { @@ -251,7 +258,7 @@ impl Locker { .unwrap(); let res = (&mut cancel_guard.recv).await.unwrap(); cancel_guard.channel.take(); - Ok(res) + res } } @@ -297,6 +304,115 @@ async fn get_action( } } +struct LockOrderEnforcer { + locks_held: OrdMap>, +} +impl LockOrderEnforcer { + fn new() -> Self { + LockOrderEnforcer { + locks_held: ordmap! {}, + } + } + // locks must be acquired in lexicographic order for the pointer, and reverse order for type + fn validate(&self, req: &LockInfo) -> Result<(), LockError> { + match self.locks_held.get(&req.handle_id) { + None => Ok(()), + Some(m) => { + // quick accept + for (ptr, ty) in m.keys() { + let tmp = LockInfo { + ptr: ptr.clone(), + ty: *ty, + handle_id: req.handle_id.clone(), + }; + if tmp.implicitly_grants(req) { + return Ok(()); + } + } + let err = m.keys().find_map(|(ptr, ty)| { + match ptr.cmp(&req.ptr) { + std::cmp::Ordering::Less => None, // this is OK + std::cmp::Ordering::Equal => { + if req.ty > *ty { + Some(LockError::LockTypeEscalation { + session: req.handle_id.clone(), + ptr: ptr.clone(), + first: *ty, + second: req.ty, + }) + } else { + None + } + } + std::cmp::Ordering::Greater => Some(if ptr.starts_with(&req.ptr) { + LockError::LockTaxonomyEscalation { + session: req.handle_id.clone(), + first: ptr.clone(), + second: req.ptr.clone(), + } + } else { + LockError::NonCanonicalOrdering { + session: req.handle_id.clone(), + first: ptr.clone(), + second: req.ptr.clone(), + } + }), + } + }); + err.map_or(Ok(()), Err) + } + } + } + fn try_insert(&mut self, req: &LockInfo) -> Result<(), LockError> { + self.validate(req)?; + match self.locks_held.get_mut(&req.handle_id) { + None => { + self.locks_held.insert( + req.handle_id.clone(), + ordmap![(req.ptr.clone(), req.ty) => 1], + ); + } + Some(locks) => { + let k = (req.ptr.clone(), req.ty); + match locks.get_mut(&k) { + None => { + locks.insert(k, 1); + } + Some(n) => { + *n += 1; + } + } + } + } + Ok(()) + } + fn remove(&mut self, req: &LockInfo) { + match self.locks_held.remove_with_key(&req.handle_id) { + None => { + #[cfg(feature = "tracing")] + tracing::warn!("Invalid removal from session manager: {}", req); + } + Some((hdl, mut locks)) => { + let k = (req.ptr.clone(), req.ty); + match locks.remove_with_key(&k) { + None => { + #[cfg(feature = "tracing")] + tracing::warn!("Invalid removal from session manager: {}", req); + } + Some((k, n)) => { + if n - 1 > 0 { + locks.insert(k, n - 1); + } + } + } + if !locks.is_empty() { + self.locks_held.insert(hdl, locks); + } + } + } + } +} + #[derive(Debug, Default)] struct Trie { state: LockState, @@ -891,6 +1007,24 @@ impl LockInfo { } } } + fn implicitly_grants(&self, other: &LockInfo) -> bool { + self.handle_id == other.handle_id + && match self.ty { + LockType::Exist => other.ty == LockType::Exist && self.ptr.starts_with(&other.ptr), + LockType::Read => { + // E's in the ancestry + other.ty == LockType::Exist && self.ptr.starts_with(&other.ptr) + // nonexclusive locks in the subtree + || other.ty != LockType::Write && other.ptr.starts_with(&self.ptr) + } + LockType::Write => { + // E's in the ancestry + other.ty == LockType::Exist && self.ptr.starts_with(&other.ptr) + // anything in the subtree + || other.ptr.starts_with(&self.ptr) + } + } + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -944,20 +1078,26 @@ pub enum LockError { struct Request { lock_info: LockInfo, cancel: Option>, - completion: oneshot::Sender, + completion: oneshot::Sender>, } impl Request { fn complete(self) -> oneshot::Receiver { let (sender, receiver) = oneshot::channel(); - if let Err(_) = self.completion.send(Guard { + if let Err(_) = self.completion.send(Ok(Guard { lock_info: self.lock_info, sender: Some(sender), - }) { + })) { #[cfg(feature = "tracing")] tracing::warn!("Completion sent to closed channel.") } receiver } + fn reject(self, err: LockError) { + if let Err(_) = self.completion.send(Err(err)) { + #[cfg(feature = "tracing")] + tracing::warn!("Rejection sent to closed channel.") + } + } } #[derive(Debug)]