cancel safety

This commit is contained in:
Aiden McClelland
2021-10-21 16:31:07 -06:00
committed by Aiden McClelland
parent 2205eb34e6
commit a724d6ec6b
6 changed files with 162 additions and 47 deletions

View File

@@ -12,7 +12,8 @@ version = "0.1.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features] [features]
debug = ["log"] debug = ["tracing"]
trace = ["debug", "tracing-error"]
[dependencies] [dependencies]
async-trait = "0.1.42" async-trait = "0.1.42"
@@ -21,8 +22,9 @@ futures = "0.3.8"
json-patch = { path = "../json-patch" } json-patch = { path = "../json-patch" }
json-ptr = { path = "../json-ptr" } json-ptr = { path = "../json-ptr" }
lazy_static = "1.4.0" lazy_static = "1.4.0"
log = { version = "*", optional = true } tracing = { version = "0.1.29", optional = true }
nix = "0.22.1" tracing-error = { version = "0.1.2", optional = true }
nix = "0.23.0"
patch-db-macro = { path = "../patch-db-macro" } patch-db-macro = { path = "../patch-db-macro" }
serde = { version = "1.0.118", features = ["rc"] } serde = { version = "1.0.118", features = ["rc"] }
serde_cbor = { path = "../cbor" } serde_cbor = { path = "../cbor" }

View File

@@ -5,14 +5,25 @@ use json_ptr::{JsonPointer, SegList};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use tokio::sync::RwLockWriteGuard;
use tokio::sync::{broadcast::Receiver, RwLock, RwLockReadGuard}; use tokio::sync::{broadcast::Receiver, RwLock, RwLockReadGuard};
use crate::locker::LockType; use crate::locker::LockType;
use crate::{locker::Guard, Locker, PatchDb, Revision, Store, Transaction}; use crate::{locker::Guard, Locker, PatchDb, Revision, Store, Transaction};
use crate::{patch::DiffPatch, Error}; use crate::{patch::DiffPatch, Error};
#[derive(Debug, Clone, PartialEq, Eq, Default)] #[derive(Debug, Clone, Default)]
pub struct HandleId(pub(crate) u64); pub struct HandleId {
pub(crate) id: u64,
#[cfg(feature = "trace")]
pub(crate) trace: Option<Arc<tracing_error::SpanTrace>>,
}
impl PartialEq for HandleId {
fn eq(&self, other: &Self) -> bool {
self.id == other.id
}
}
impl Eq for HandleId {}
#[async_trait] #[async_trait]
pub trait DbHandle: Send + Sync { pub trait DbHandle: Send + Sync {
@@ -42,7 +53,11 @@ pub trait DbHandle: Send + Sync {
ptr: &JsonPointer<S, V>, ptr: &JsonPointer<S, V>,
value: &Value, value: &Value,
) -> Result<Option<Arc<Revision>>, Error>; ) -> Result<Option<Arc<Revision>>, Error>;
async fn apply(&mut self, patch: DiffPatch) -> Result<Option<Arc<Revision>>, Error>; async fn apply(
&mut self,
patch: DiffPatch,
store_write_lock: Option<RwLockWriteGuard<'_, Store>>,
) -> Result<Option<Arc<Revision>>, Error>;
async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) -> (); async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) -> ();
async fn get< async fn get<
T: for<'de> Deserialize<'de>, T: for<'de> Deserialize<'de>,
@@ -122,8 +137,12 @@ impl<Handle: DbHandle + ?Sized> DbHandle for &mut Handle {
) -> Result<Option<Arc<Revision>>, Error> { ) -> Result<Option<Arc<Revision>>, Error> {
(*self).put_value(ptr, value).await (*self).put_value(ptr, value).await
} }
async fn apply(&mut self, patch: DiffPatch) -> Result<Option<Arc<Revision>>, Error> { async fn apply(
(*self).apply(patch).await &mut self,
patch: DiffPatch,
store_write_lock: Option<RwLockWriteGuard<'_, Store>>,
) -> Result<Option<Arc<Revision>>, Error> {
(*self).apply(patch, store_write_lock).await
} }
async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) { async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) {
(*self).lock(ptr, lock_type).await (*self).lock(ptr, lock_type).await
@@ -230,8 +249,12 @@ impl DbHandle for PatchDbHandle {
) -> Result<Option<Arc<Revision>>, Error> { ) -> Result<Option<Arc<Revision>>, Error> {
self.db.put(ptr, value, None).await self.db.put(ptr, value, None).await
} }
async fn apply(&mut self, patch: DiffPatch) -> Result<Option<Arc<Revision>>, Error> { async fn apply(
self.db.apply(patch, None, None).await &mut self,
patch: DiffPatch,
store_write_lock: Option<RwLockWriteGuard<'_, Store>>,
) -> Result<Option<Arc<Revision>>, Error> {
self.db.apply(patch, None, store_write_lock).await
} }
async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) { async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) {
self.locks self.locks

View File

@@ -21,35 +21,70 @@ impl Locker {
// instead we want it to block forever by adding a channel that will never recv // instead we want it to block forever by adding a channel that will never recv
let (_dummy_send, dummy_recv) = oneshot::channel(); let (_dummy_send, dummy_recv) = oneshot::channel();
let mut locks_on_lease = vec![dummy_recv]; let mut locks_on_lease = vec![dummy_recv];
while let Some(action) = get_action(&mut new_requests, &mut locks_on_lease).await { let (_dummy_send, dummy_recv) = oneshot::channel();
#[cfg(feature = "log")] let mut cancellations = vec![dummy_recv];
log::trace!("Locker Action: {:#?}", action); while let Some(action) =
get_action(&mut new_requests, &mut locks_on_lease, &mut cancellations).await
{
#[cfg(feature = "tracing")]
tracing::trace!("Locker Action: {:#?}", action);
match action { match action {
Action::HandleRequest(req) => trie.handle_request(req, &mut locks_on_lease), Action::HandleRequest(mut req) => {
cancellations.extend(req.cancel.take());
trie.handle_request(req, &mut locks_on_lease)
}
Action::HandleRelease(lock_info) => { Action::HandleRelease(lock_info) => {
trie.handle_release(lock_info, &mut locks_on_lease) trie.handle_release(lock_info, &mut locks_on_lease)
} }
Action::HandleCancel(lock_info) => {
trie.handle_cancel(lock_info, &mut locks_on_lease)
} }
#[cfg(feature = "log")] }
log::trace!("Locker Trie: {:#?}", trie); #[cfg(feature = "tracing")]
tracing::trace!("Locker Trie: {:#?}", trie);
} }
}); });
Locker { sender } Locker { sender }
} }
pub async fn lock(&self, handle_id: HandleId, ptr: JsonPointer, lock_type: LockType) -> Guard { pub async fn lock(&self, handle_id: HandleId, ptr: JsonPointer, lock_type: LockType) -> Guard {
let (send, recv) = oneshot::channel(); struct CancelGuard {
self.sender lock_info: Option<LockInfo>,
.send(Request { channel: Option<oneshot::Sender<LockInfo>>,
lock_info: LockInfo { recv: oneshot::Receiver<Guard>,
}
impl Drop for CancelGuard {
fn drop(&mut self) {
if let (Some(lock_info), Some(channel)) =
(self.lock_info.take(), self.channel.take())
{
self.recv.close();
let _ = channel.send(lock_info);
}
}
}
let lock_info = LockInfo {
handle_id, handle_id,
ptr, ptr,
ty: lock_type, ty: lock_type,
segments_handled: 0, segments_handled: 0,
}, };
let (send, recv) = oneshot::channel();
let (cancel_send, cancel_recv) = oneshot::channel();
let mut cancel_guard = CancelGuard {
lock_info: Some(lock_info.clone()),
channel: Some(cancel_send),
recv,
};
self.sender
.send(Request {
lock_info,
cancel: Some(cancel_recv),
completion: send, completion: send,
}) })
.unwrap(); .unwrap();
recv.await.unwrap() let res = (&mut cancel_guard.recv).await.unwrap();
cancel_guard.channel.take();
res
} }
} }
@@ -62,13 +97,15 @@ struct RequestQueue {
enum Action { enum Action {
HandleRequest(Request), HandleRequest(Request),
HandleRelease(LockInfo), HandleRelease(LockInfo),
HandleCancel(LockInfo),
} }
async fn get_action( async fn get_action(
new_requests: &mut RequestQueue, new_requests: &mut RequestQueue,
locks_on_lease: &mut Vec<oneshot::Receiver<LockInfo>>, locks_on_lease: &mut Vec<oneshot::Receiver<LockInfo>>,
cancellations: &mut Vec<oneshot::Receiver<LockInfo>>,
) -> Option<Action> { ) -> Option<Action> {
loop { loop {
if new_requests.closed && locks_on_lease.is_empty() { if new_requests.closed && locks_on_lease.len() == 1 && cancellations.len() == 1 {
return None; return None;
} }
tokio::select! { tokio::select! {
@@ -83,6 +120,12 @@ async fn get_action(
locks_on_lease.swap_remove(idx); locks_on_lease.swap_remove(idx);
return Some(Action::HandleRelease(a.unwrap())) return Some(Action::HandleRelease(a.unwrap()))
} }
(a, idx, _) = futures::future::select_all(cancellations.iter_mut()) => {
cancellations.swap_remove(idx);
if let Ok(a) = a {
return Some(Action::HandleCancel(a))
}
}
} }
} }
} }
@@ -123,6 +166,20 @@ impl Trie {
.handle_release(release, locks_on_lease) .handle_release(release, locks_on_lease)
} }
} }
fn handle_cancel(
&mut self,
lock_info: LockInfo,
locks_on_lease: &mut Vec<oneshot::Receiver<LockInfo>>,
) {
let cancel = self.node.cancel(lock_info);
for req in std::mem::take(&mut self.node.reqs) {
self.handle_request(req, locks_on_lease);
}
if let Some(cancel) = cancel {
self.child_mut(cancel.current_seg())
.handle_cancel(cancel, locks_on_lease)
}
}
} }
#[derive(Debug, Default)] #[derive(Debug, Default)]
@@ -273,9 +330,26 @@ impl Node {
Some(lock_info) Some(lock_info)
} }
} }
fn cancel(&mut self, mut lock_info: LockInfo) -> Option<LockInfo> {
let mut idx = 0;
while idx < self.reqs.len() {
if self.reqs[idx].completion.is_closed() && self.reqs[idx].lock_info == lock_info {
self.reqs.swap_remove(idx);
return None;
} else {
idx += 1;
}
}
if lock_info.ptr.len() == lock_info.segments_handled {
None
} else {
lock_info.segments_handled += 1;
Some(lock_info)
}
}
} }
#[derive(Debug, Default)] #[derive(Debug, Clone, Default, PartialEq)]
struct LockInfo { struct LockInfo {
ptr: JsonPointer, ptr: JsonPointer,
segments_handled: usize, segments_handled: usize,
@@ -313,6 +387,7 @@ impl Default for LockType {
#[derive(Debug)] #[derive(Debug)]
struct Request { struct Request {
lock_info: LockInfo, lock_info: LockInfo,
cancel: Option<oneshot::Receiver<LockInfo>>,
completion: oneshot::Sender<Guard>, completion: oneshot::Sender<Guard>,
} }
impl Request { impl Request {
@@ -328,10 +403,13 @@ impl Request {
fn complete(self, locks_on_lease: &mut Vec<oneshot::Receiver<LockInfo>>) { fn complete(self, locks_on_lease: &mut Vec<oneshot::Receiver<LockInfo>>) {
let (sender, receiver) = oneshot::channel(); let (sender, receiver) = oneshot::channel();
locks_on_lease.push(receiver); locks_on_lease.push(receiver);
let _ = self.completion.send(Guard { if let Err(_) = self.completion.send(Guard {
lock_info: self.lock_info.reset(), lock_info: self.lock_info.reset(),
sender: Some(sender), sender: Some(sender),
}); }) {
#[cfg(feature = "tracing")]
tracing::warn!("Completion sent to closed channel.")
}
} }
} }
@@ -342,10 +420,14 @@ pub struct Guard {
} }
impl Drop for Guard { impl Drop for Guard {
fn drop(&mut self) { fn drop(&mut self) {
let _ = self if let Err(_e) = self
.sender .sender
.take() .take()
.unwrap() .unwrap()
.send(std::mem::take(&mut self.lock_info)); .send(std::mem::take(&mut self.lock_info))
{
#[cfg(feature = "tracing")]
tracing::warn!("Failed to release lock: {:?}", _e)
}
} }
} }

View File

@@ -33,13 +33,14 @@ pub struct ModelDataMut<T: Serialize + for<'de> Deserialize<'de>> {
ptr: JsonPointer, ptr: JsonPointer,
} }
impl<T: Serialize + for<'de> Deserialize<'de>> ModelDataMut<T> { impl<T: Serialize + for<'de> Deserialize<'de>> ModelDataMut<T> {
pub async fn save<Db: DbHandle>(self, db: &mut Db) -> Result<(), Error> { pub async fn save<Db: DbHandle>(&mut self, db: &mut Db) -> Result<(), Error> {
let current = serde_json::to_value(&self.current)?; let current = serde_json::to_value(&self.current)?;
let mut diff = crate::patch::diff(&self.original, &current); let mut diff = crate::patch::diff(&self.original, &current);
let target = db.get_value(&self.ptr, None).await?; let target = db.get_value(&self.ptr, None).await?;
diff.rebase(&crate::patch::diff(&self.original, &target)); diff.rebase(&crate::patch::diff(&self.original, &target));
diff.prepend(&self.ptr); diff.prepend(&self.ptr);
db.apply(diff).await?; db.apply(diff, None).await?;
self.original = current;
Ok(()) Ok(())
} }
} }
@@ -501,11 +502,12 @@ where
pub async fn remove<Db: DbHandle>(&self, db: &mut Db, key: &T::Key) -> Result<(), Error> { pub async fn remove<Db: DbHandle>(&self, db: &mut Db, key: &T::Key) -> Result<(), Error> {
db.lock(self.as_ref().clone(), LockType::Write).await; db.lock(self.as_ref().clone(), LockType::Write).await;
if db.exists(self.clone().idx(key).as_ref(), None).await? { if db.exists(self.clone().idx(key).as_ref(), None).await? {
db.apply(DiffPatch(Patch(vec![PatchOperation::Remove( db.apply(
RemoveOperation { DiffPatch(Patch(vec![PatchOperation::Remove(RemoveOperation {
path: self.as_ref().clone().join_end(key.as_ref()), path: self.as_ref().clone().join_end(key.as_ref()),
}, })])),
)]))) None,
)
.await?; .await?;
} }
Ok(()) Ok(())

View File

@@ -177,8 +177,8 @@ impl Store {
return Ok(None); return Ok(None);
} }
#[cfg(feature = "log")] #[cfg(feature = "tracing")]
log::trace!("Attempting to apply patch: {:?}", patch); tracing::trace!("Attempting to apply patch: {:?}", patch);
self.check_cache_corrupted()?; self.check_cache_corrupted()?;
let patch_bin = serde_cbor::to_vec(&*patch)?; let patch_bin = serde_cbor::to_vec(&*patch)?;
@@ -288,10 +288,13 @@ impl PatchDb {
} }
pub fn handle(&self) -> PatchDbHandle { pub fn handle(&self) -> PatchDbHandle {
PatchDbHandle { PatchDbHandle {
id: HandleId( id: HandleId {
self.handle_id id: self
.handle_id
.fetch_add(1, std::sync::atomic::Ordering::SeqCst), .fetch_add(1, std::sync::atomic::Ordering::SeqCst),
), #[cfg(feature = "trace")]
trace: Some(Arc::new(tracing_error::SpanTrace::capture())),
},
db: self.clone(), db: self.clone(),
locks: Vec::new(), locks: Vec::new(),
} }

View File

@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use tokio::sync::broadcast::error::TryRecvError; use tokio::sync::broadcast::error::TryRecvError;
use tokio::sync::broadcast::Receiver; use tokio::sync::broadcast::Receiver;
use tokio::sync::{RwLock, RwLockReadGuard}; use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use crate::handle::HandleId; use crate::handle::HandleId;
use crate::locker::{Guard, LockType, Locker}; use crate::locker::{Guard, LockType, Locker};
@@ -51,10 +51,9 @@ impl Transaction<&mut PatchDbHandle> {
impl<Parent: DbHandle + Send + Sync> Transaction<Parent> { impl<Parent: DbHandle + Send + Sync> Transaction<Parent> {
pub async fn save(mut self) -> Result<(), Error> { pub async fn save(mut self) -> Result<(), Error> {
let store_lock = self.parent.store(); let store_lock = self.parent.store();
let store = store_lock.read().await; let store = store_lock.write().await;
self.rebase()?; self.rebase()?;
self.parent.apply(self.updates).await?; self.parent.apply(self.updates, Some(store)).await?;
drop(store);
Ok(()) Ok(())
} }
} }
@@ -148,8 +147,8 @@ impl<Parent: DbHandle + Send + Sync> DbHandle for Transaction<Parent> {
}; };
let path_updates = self.updates.for_path(ptr); let path_updates = self.updates.for_path(ptr);
if !(path_updates.0).0.is_empty() { if !(path_updates.0).0.is_empty() {
#[cfg(feature = "log")] #[cfg(feature = "tracing")]
log::trace!("applying patch {:?} at path {}", path_updates, ptr); tracing::trace!("Applying patch {:?} at path {}", path_updates, ptr);
json_patch::patch(&mut data, &*path_updates)?; json_patch::patch(&mut data, &*path_updates)?;
} }
@@ -195,7 +194,11 @@ impl<Parent: DbHandle + Send + Sync> DbHandle for Transaction<Parent> {
) -> Result<Option<Arc<Revision>>, Error> { ) -> Result<Option<Arc<Revision>>, Error> {
self.put_value(ptr, &serde_json::to_value(value)?).await self.put_value(ptr, &serde_json::to_value(value)?).await
} }
async fn apply(&mut self, patch: DiffPatch) -> Result<Option<Arc<Revision>>, Error> { async fn apply(
&mut self,
patch: DiffPatch,
_store_write_lock: Option<RwLockWriteGuard<'_, Store>>,
) -> Result<Option<Arc<Revision>>, Error> {
self.updates.append(patch); self.updates.append(patch);
Ok(None) Ok(None)
} }