mirror of
https://github.com/Start9Labs/patch-db.git
synced 2026-03-26 10:21:53 +00:00
enforce correct api use
This commit is contained in:
committed by
Aiden McClelland
parent
d293e9a1f9
commit
0860275869
@@ -1,6 +1,6 @@
|
|||||||
use std::collections::{BTreeMap, VecDeque};
|
use std::collections::{BTreeMap, VecDeque};
|
||||||
|
|
||||||
use imbl::{ordset, OrdSet};
|
use imbl::{ordmap, ordset, OrdMap, OrdSet};
|
||||||
use json_ptr::{JsonPointer, SegList};
|
use json_ptr::{JsonPointer, SegList};
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use proptest::prelude::*;
|
use proptest::prelude::*;
|
||||||
@@ -27,6 +27,8 @@ impl Locker {
|
|||||||
let mut cancellations = vec![dummy_recv];
|
let mut cancellations = vec![dummy_recv];
|
||||||
|
|
||||||
let mut request_queue = VecDeque::<(Request, OrdSet<HandleId>)>::new();
|
let mut request_queue = VecDeque::<(Request, OrdSet<HandleId>)>::new();
|
||||||
|
let mut lock_order_enforcer = LockOrderEnforcer::new();
|
||||||
|
|
||||||
while let Some(action) =
|
while let Some(action) =
|
||||||
get_action(&mut new_requests, &mut locks_on_lease, &mut cancellations).await
|
get_action(&mut new_requests, &mut locks_on_lease, &mut cancellations).await
|
||||||
{
|
{
|
||||||
@@ -110,21 +112,26 @@ impl Locker {
|
|||||||
Action::HandleRequest(mut req) => {
|
Action::HandleRequest(mut req) => {
|
||||||
#[cfg(feature = "tracing")]
|
#[cfg(feature = "tracing")]
|
||||||
tracing::debug!("New lock request");
|
tracing::debug!("New lock request");
|
||||||
cancellations.extend(req.cancel.take());
|
if let Err(e) = lock_order_enforcer.try_insert(&req.lock_info) {
|
||||||
let hot_seat = request_queue.pop_front();
|
req.reject(e)
|
||||||
process_new_req(
|
} else {
|
||||||
hot_seat.as_ref(),
|
cancellations.extend(req.cancel.take());
|
||||||
req,
|
let hot_seat = request_queue.pop_front();
|
||||||
&mut trie,
|
process_new_req(
|
||||||
&mut locks_on_lease,
|
hot_seat.as_ref(),
|
||||||
&mut request_queue,
|
req,
|
||||||
);
|
&mut trie,
|
||||||
if let Some(hot_seat) = hot_seat {
|
&mut locks_on_lease,
|
||||||
request_queue.push_front(hot_seat);
|
&mut request_queue,
|
||||||
|
);
|
||||||
|
if let Some(hot_seat) = hot_seat {
|
||||||
|
request_queue.push_front(hot_seat);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Action::HandleRelease(lock_info) => {
|
Action::HandleRelease(lock_info) => {
|
||||||
// release actual lock
|
// release actual lock
|
||||||
|
lock_order_enforcer.remove(&lock_info);
|
||||||
trie.unlock(&lock_info);
|
trie.unlock(&lock_info);
|
||||||
#[cfg(feature = "tracing")]
|
#[cfg(feature = "tracing")]
|
||||||
{
|
{
|
||||||
@@ -176,7 +183,7 @@ impl Locker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Action::HandleCancel(lock_info) => {
|
Action::HandleCancel(lock_info) => {
|
||||||
// trie.handle_cancel(lock_info, &mut locks_on_lease)
|
lock_order_enforcer.remove(&lock_info);
|
||||||
let entry = request_queue
|
let entry = request_queue
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
@@ -218,7 +225,7 @@ impl Locker {
|
|||||||
struct CancelGuard {
|
struct CancelGuard {
|
||||||
lock_info: Option<LockInfo>,
|
lock_info: Option<LockInfo>,
|
||||||
channel: Option<oneshot::Sender<LockInfo>>,
|
channel: Option<oneshot::Sender<LockInfo>>,
|
||||||
recv: oneshot::Receiver<Guard>,
|
recv: oneshot::Receiver<Result<Guard, LockError>>,
|
||||||
}
|
}
|
||||||
impl Drop for CancelGuard {
|
impl Drop for CancelGuard {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
@@ -251,7 +258,7 @@ impl Locker {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
let res = (&mut cancel_guard.recv).await.unwrap();
|
let res = (&mut cancel_guard.recv).await.unwrap();
|
||||||
cancel_guard.channel.take();
|
cancel_guard.channel.take();
|
||||||
Ok(res)
|
res
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -297,6 +304,115 @@ async fn get_action(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct LockOrderEnforcer {
|
||||||
|
locks_held: OrdMap<HandleId, OrdMap<(JsonPointer, LockType), usize>>,
|
||||||
|
}
|
||||||
|
impl LockOrderEnforcer {
|
||||||
|
fn new() -> Self {
|
||||||
|
LockOrderEnforcer {
|
||||||
|
locks_held: ordmap! {},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// locks must be acquired in lexicographic order for the pointer, and reverse order for type
|
||||||
|
fn validate(&self, req: &LockInfo) -> Result<(), LockError> {
|
||||||
|
match self.locks_held.get(&req.handle_id) {
|
||||||
|
None => Ok(()),
|
||||||
|
Some(m) => {
|
||||||
|
// quick accept
|
||||||
|
for (ptr, ty) in m.keys() {
|
||||||
|
let tmp = LockInfo {
|
||||||
|
ptr: ptr.clone(),
|
||||||
|
ty: *ty,
|
||||||
|
handle_id: req.handle_id.clone(),
|
||||||
|
};
|
||||||
|
if tmp.implicitly_grants(req) {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let err = m.keys().find_map(|(ptr, ty)| {
|
||||||
|
match ptr.cmp(&req.ptr) {
|
||||||
|
std::cmp::Ordering::Less => None, // this is OK
|
||||||
|
std::cmp::Ordering::Equal => {
|
||||||
|
if req.ty > *ty {
|
||||||
|
Some(LockError::LockTypeEscalation {
|
||||||
|
session: req.handle_id.clone(),
|
||||||
|
ptr: ptr.clone(),
|
||||||
|
first: *ty,
|
||||||
|
second: req.ty,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::cmp::Ordering::Greater => Some(if ptr.starts_with(&req.ptr) {
|
||||||
|
LockError::LockTaxonomyEscalation {
|
||||||
|
session: req.handle_id.clone(),
|
||||||
|
first: ptr.clone(),
|
||||||
|
second: req.ptr.clone(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
LockError::NonCanonicalOrdering {
|
||||||
|
session: req.handle_id.clone(),
|
||||||
|
first: ptr.clone(),
|
||||||
|
second: req.ptr.clone(),
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
});
|
||||||
|
err.map_or(Ok(()), Err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn try_insert(&mut self, req: &LockInfo) -> Result<(), LockError> {
|
||||||
|
self.validate(req)?;
|
||||||
|
match self.locks_held.get_mut(&req.handle_id) {
|
||||||
|
None => {
|
||||||
|
self.locks_held.insert(
|
||||||
|
req.handle_id.clone(),
|
||||||
|
ordmap![(req.ptr.clone(), req.ty) => 1],
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Some(locks) => {
|
||||||
|
let k = (req.ptr.clone(), req.ty);
|
||||||
|
match locks.get_mut(&k) {
|
||||||
|
None => {
|
||||||
|
locks.insert(k, 1);
|
||||||
|
}
|
||||||
|
Some(n) => {
|
||||||
|
*n += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
fn remove(&mut self, req: &LockInfo) {
|
||||||
|
match self.locks_held.remove_with_key(&req.handle_id) {
|
||||||
|
None => {
|
||||||
|
#[cfg(feature = "tracing")]
|
||||||
|
tracing::warn!("Invalid removal from session manager: {}", req);
|
||||||
|
}
|
||||||
|
Some((hdl, mut locks)) => {
|
||||||
|
let k = (req.ptr.clone(), req.ty);
|
||||||
|
match locks.remove_with_key(&k) {
|
||||||
|
None => {
|
||||||
|
#[cfg(feature = "tracing")]
|
||||||
|
tracing::warn!("Invalid removal from session manager: {}", req);
|
||||||
|
}
|
||||||
|
Some((k, n)) => {
|
||||||
|
if n - 1 > 0 {
|
||||||
|
locks.insert(k, n - 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !locks.is_empty() {
|
||||||
|
self.locks_held.insert(hdl, locks);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
struct Trie {
|
struct Trie {
|
||||||
state: LockState,
|
state: LockState,
|
||||||
@@ -891,6 +1007,24 @@ impl LockInfo {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
fn implicitly_grants(&self, other: &LockInfo) -> bool {
|
||||||
|
self.handle_id == other.handle_id
|
||||||
|
&& match self.ty {
|
||||||
|
LockType::Exist => other.ty == LockType::Exist && self.ptr.starts_with(&other.ptr),
|
||||||
|
LockType::Read => {
|
||||||
|
// E's in the ancestry
|
||||||
|
other.ty == LockType::Exist && self.ptr.starts_with(&other.ptr)
|
||||||
|
// nonexclusive locks in the subtree
|
||||||
|
|| other.ty != LockType::Write && other.ptr.starts_with(&self.ptr)
|
||||||
|
}
|
||||||
|
LockType::Write => {
|
||||||
|
// E's in the ancestry
|
||||||
|
other.ty == LockType::Exist && self.ptr.starts_with(&other.ptr)
|
||||||
|
// anything in the subtree
|
||||||
|
|| other.ptr.starts_with(&self.ptr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||||
@@ -944,20 +1078,26 @@ pub enum LockError {
|
|||||||
struct Request {
|
struct Request {
|
||||||
lock_info: LockInfo,
|
lock_info: LockInfo,
|
||||||
cancel: Option<oneshot::Receiver<LockInfo>>,
|
cancel: Option<oneshot::Receiver<LockInfo>>,
|
||||||
completion: oneshot::Sender<Guard>,
|
completion: oneshot::Sender<Result<Guard, LockError>>,
|
||||||
}
|
}
|
||||||
impl Request {
|
impl Request {
|
||||||
fn complete(self) -> oneshot::Receiver<LockInfo> {
|
fn complete(self) -> oneshot::Receiver<LockInfo> {
|
||||||
let (sender, receiver) = oneshot::channel();
|
let (sender, receiver) = oneshot::channel();
|
||||||
if let Err(_) = self.completion.send(Guard {
|
if let Err(_) = self.completion.send(Ok(Guard {
|
||||||
lock_info: self.lock_info,
|
lock_info: self.lock_info,
|
||||||
sender: Some(sender),
|
sender: Some(sender),
|
||||||
}) {
|
})) {
|
||||||
#[cfg(feature = "tracing")]
|
#[cfg(feature = "tracing")]
|
||||||
tracing::warn!("Completion sent to closed channel.")
|
tracing::warn!("Completion sent to closed channel.")
|
||||||
}
|
}
|
||||||
receiver
|
receiver
|
||||||
}
|
}
|
||||||
|
fn reject(self, err: LockError) {
|
||||||
|
if let Err(_) = self.completion.send(Err(err)) {
|
||||||
|
#[cfg(feature = "tracing")]
|
||||||
|
tracing::warn!("Rejection sent to closed channel.")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|||||||
Reference in New Issue
Block a user