mirror of
https://github.com/Start9Labs/patch-db.git
synced 2026-03-26 02:11:54 +00:00
enforce correct api use
This commit is contained in:
committed by
Aiden McClelland
parent
d293e9a1f9
commit
0860275869
@@ -1,6 +1,6 @@
|
||||
use std::collections::{BTreeMap, VecDeque};
|
||||
|
||||
use imbl::{ordset, OrdSet};
|
||||
use imbl::{ordmap, ordset, OrdMap, OrdSet};
|
||||
use json_ptr::{JsonPointer, SegList};
|
||||
#[cfg(test)]
|
||||
use proptest::prelude::*;
|
||||
@@ -27,6 +27,8 @@ impl Locker {
|
||||
let mut cancellations = vec![dummy_recv];
|
||||
|
||||
let mut request_queue = VecDeque::<(Request, OrdSet<HandleId>)>::new();
|
||||
let mut lock_order_enforcer = LockOrderEnforcer::new();
|
||||
|
||||
while let Some(action) =
|
||||
get_action(&mut new_requests, &mut locks_on_lease, &mut cancellations).await
|
||||
{
|
||||
@@ -110,21 +112,26 @@ impl Locker {
|
||||
Action::HandleRequest(mut req) => {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::debug!("New lock request");
|
||||
cancellations.extend(req.cancel.take());
|
||||
let hot_seat = request_queue.pop_front();
|
||||
process_new_req(
|
||||
hot_seat.as_ref(),
|
||||
req,
|
||||
&mut trie,
|
||||
&mut locks_on_lease,
|
||||
&mut request_queue,
|
||||
);
|
||||
if let Some(hot_seat) = hot_seat {
|
||||
request_queue.push_front(hot_seat);
|
||||
if let Err(e) = lock_order_enforcer.try_insert(&req.lock_info) {
|
||||
req.reject(e)
|
||||
} else {
|
||||
cancellations.extend(req.cancel.take());
|
||||
let hot_seat = request_queue.pop_front();
|
||||
process_new_req(
|
||||
hot_seat.as_ref(),
|
||||
req,
|
||||
&mut trie,
|
||||
&mut locks_on_lease,
|
||||
&mut request_queue,
|
||||
);
|
||||
if let Some(hot_seat) = hot_seat {
|
||||
request_queue.push_front(hot_seat);
|
||||
}
|
||||
}
|
||||
}
|
||||
Action::HandleRelease(lock_info) => {
|
||||
// release actual lock
|
||||
lock_order_enforcer.remove(&lock_info);
|
||||
trie.unlock(&lock_info);
|
||||
#[cfg(feature = "tracing")]
|
||||
{
|
||||
@@ -176,7 +183,7 @@ impl Locker {
|
||||
}
|
||||
}
|
||||
Action::HandleCancel(lock_info) => {
|
||||
// trie.handle_cancel(lock_info, &mut locks_on_lease)
|
||||
lock_order_enforcer.remove(&lock_info);
|
||||
let entry = request_queue
|
||||
.iter()
|
||||
.enumerate()
|
||||
@@ -218,7 +225,7 @@ impl Locker {
|
||||
struct CancelGuard {
|
||||
lock_info: Option<LockInfo>,
|
||||
channel: Option<oneshot::Sender<LockInfo>>,
|
||||
recv: oneshot::Receiver<Guard>,
|
||||
recv: oneshot::Receiver<Result<Guard, LockError>>,
|
||||
}
|
||||
impl Drop for CancelGuard {
|
||||
fn drop(&mut self) {
|
||||
@@ -251,7 +258,7 @@ impl Locker {
|
||||
.unwrap();
|
||||
let res = (&mut cancel_guard.recv).await.unwrap();
|
||||
cancel_guard.channel.take();
|
||||
Ok(res)
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
@@ -297,6 +304,115 @@ async fn get_action(
|
||||
}
|
||||
}
|
||||
|
||||
struct LockOrderEnforcer {
|
||||
locks_held: OrdMap<HandleId, OrdMap<(JsonPointer, LockType), usize>>,
|
||||
}
|
||||
impl LockOrderEnforcer {
|
||||
fn new() -> Self {
|
||||
LockOrderEnforcer {
|
||||
locks_held: ordmap! {},
|
||||
}
|
||||
}
|
||||
// locks must be acquired in lexicographic order for the pointer, and reverse order for type
|
||||
fn validate(&self, req: &LockInfo) -> Result<(), LockError> {
|
||||
match self.locks_held.get(&req.handle_id) {
|
||||
None => Ok(()),
|
||||
Some(m) => {
|
||||
// quick accept
|
||||
for (ptr, ty) in m.keys() {
|
||||
let tmp = LockInfo {
|
||||
ptr: ptr.clone(),
|
||||
ty: *ty,
|
||||
handle_id: req.handle_id.clone(),
|
||||
};
|
||||
if tmp.implicitly_grants(req) {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
let err = m.keys().find_map(|(ptr, ty)| {
|
||||
match ptr.cmp(&req.ptr) {
|
||||
std::cmp::Ordering::Less => None, // this is OK
|
||||
std::cmp::Ordering::Equal => {
|
||||
if req.ty > *ty {
|
||||
Some(LockError::LockTypeEscalation {
|
||||
session: req.handle_id.clone(),
|
||||
ptr: ptr.clone(),
|
||||
first: *ty,
|
||||
second: req.ty,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
std::cmp::Ordering::Greater => Some(if ptr.starts_with(&req.ptr) {
|
||||
LockError::LockTaxonomyEscalation {
|
||||
session: req.handle_id.clone(),
|
||||
first: ptr.clone(),
|
||||
second: req.ptr.clone(),
|
||||
}
|
||||
} else {
|
||||
LockError::NonCanonicalOrdering {
|
||||
session: req.handle_id.clone(),
|
||||
first: ptr.clone(),
|
||||
second: req.ptr.clone(),
|
||||
}
|
||||
}),
|
||||
}
|
||||
});
|
||||
err.map_or(Ok(()), Err)
|
||||
}
|
||||
}
|
||||
}
|
||||
fn try_insert(&mut self, req: &LockInfo) -> Result<(), LockError> {
|
||||
self.validate(req)?;
|
||||
match self.locks_held.get_mut(&req.handle_id) {
|
||||
None => {
|
||||
self.locks_held.insert(
|
||||
req.handle_id.clone(),
|
||||
ordmap![(req.ptr.clone(), req.ty) => 1],
|
||||
);
|
||||
}
|
||||
Some(locks) => {
|
||||
let k = (req.ptr.clone(), req.ty);
|
||||
match locks.get_mut(&k) {
|
||||
None => {
|
||||
locks.insert(k, 1);
|
||||
}
|
||||
Some(n) => {
|
||||
*n += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
fn remove(&mut self, req: &LockInfo) {
|
||||
match self.locks_held.remove_with_key(&req.handle_id) {
|
||||
None => {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::warn!("Invalid removal from session manager: {}", req);
|
||||
}
|
||||
Some((hdl, mut locks)) => {
|
||||
let k = (req.ptr.clone(), req.ty);
|
||||
match locks.remove_with_key(&k) {
|
||||
None => {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::warn!("Invalid removal from session manager: {}", req);
|
||||
}
|
||||
Some((k, n)) => {
|
||||
if n - 1 > 0 {
|
||||
locks.insert(k, n - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
if !locks.is_empty() {
|
||||
self.locks_held.insert(hdl, locks);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct Trie {
|
||||
state: LockState,
|
||||
@@ -891,6 +1007,24 @@ impl LockInfo {
|
||||
}
|
||||
}
|
||||
}
|
||||
fn implicitly_grants(&self, other: &LockInfo) -> bool {
|
||||
self.handle_id == other.handle_id
|
||||
&& match self.ty {
|
||||
LockType::Exist => other.ty == LockType::Exist && self.ptr.starts_with(&other.ptr),
|
||||
LockType::Read => {
|
||||
// E's in the ancestry
|
||||
other.ty == LockType::Exist && self.ptr.starts_with(&other.ptr)
|
||||
// nonexclusive locks in the subtree
|
||||
|| other.ty != LockType::Write && other.ptr.starts_with(&self.ptr)
|
||||
}
|
||||
LockType::Write => {
|
||||
// E's in the ancestry
|
||||
other.ty == LockType::Exist && self.ptr.starts_with(&other.ptr)
|
||||
// anything in the subtree
|
||||
|| other.ptr.starts_with(&self.ptr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
@@ -944,20 +1078,26 @@ pub enum LockError {
|
||||
struct Request {
|
||||
lock_info: LockInfo,
|
||||
cancel: Option<oneshot::Receiver<LockInfo>>,
|
||||
completion: oneshot::Sender<Guard>,
|
||||
completion: oneshot::Sender<Result<Guard, LockError>>,
|
||||
}
|
||||
impl Request {
|
||||
fn complete(self) -> oneshot::Receiver<LockInfo> {
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
if let Err(_) = self.completion.send(Guard {
|
||||
if let Err(_) = self.completion.send(Ok(Guard {
|
||||
lock_info: self.lock_info,
|
||||
sender: Some(sender),
|
||||
}) {
|
||||
})) {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::warn!("Completion sent to closed channel.")
|
||||
}
|
||||
receiver
|
||||
}
|
||||
fn reject(self, err: LockError) {
|
||||
if let Err(_) = self.completion.send(Err(err)) {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::warn!("Rejection sent to closed channel.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
Reference in New Issue
Block a user