From 16ba75225b2470a1b28b0da64578a43e8ad522f7 Mon Sep 17 00:00:00 2001 From: Aiden McClelland Date: Thu, 21 Oct 2021 16:31:07 -0600 Subject: [PATCH] cancel safety --- patch-db/Cargo.toml | 8 ++- patch-db/src/handle.rs | 37 ++++++++--- patch-db/src/locker.rs | 120 ++++++++++++++++++++++++++++++------ patch-db/src/model.rs | 14 +++-- patch-db/src/store.rs | 13 ++-- patch-db/src/transaction.rs | 17 ++--- 6 files changed, 162 insertions(+), 47 deletions(-) diff --git a/patch-db/Cargo.toml b/patch-db/Cargo.toml index 6b658ef..c32522b 100644 --- a/patch-db/Cargo.toml +++ b/patch-db/Cargo.toml @@ -12,7 +12,8 @@ version = "0.1.0" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -debug = ["log"] +debug = ["tracing"] +trace = ["debug", "tracing-error"] [dependencies] async-trait = "0.1.42" @@ -21,8 +22,9 @@ futures = "0.3.8" json-patch = { path = "../json-patch" } json-ptr = { path = "../json-ptr" } lazy_static = "1.4.0" -log = { version = "*", optional = true } -nix = "0.22.1" +tracing = { version = "0.1.29", optional = true } +tracing-error = { version = "0.1.2", optional = true } +nix = "0.23.0" patch-db-macro = { path = "../patch-db-macro" } serde = { version = "1.0.118", features = ["rc"] } serde_cbor = { path = "../cbor" } diff --git a/patch-db/src/handle.rs b/patch-db/src/handle.rs index 5e148db..6ed461d 100644 --- a/patch-db/src/handle.rs +++ b/patch-db/src/handle.rs @@ -5,14 +5,25 @@ use json_ptr::{JsonPointer, SegList}; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::BTreeSet; +use tokio::sync::RwLockWriteGuard; use tokio::sync::{broadcast::Receiver, RwLock, RwLockReadGuard}; use crate::locker::LockType; use crate::{locker::Guard, Locker, PatchDb, Revision, Store, Transaction}; use crate::{patch::DiffPatch, Error}; -#[derive(Debug, Clone, PartialEq, Eq, Default)] -pub struct HandleId(pub(crate) u64); +#[derive(Debug, Clone, Default)] +pub struct HandleId { + pub(crate) id: u64, + #[cfg(feature = "trace")] + pub(crate) trace: Option>, +} +impl PartialEq for HandleId { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} +impl Eq for HandleId {} #[async_trait] pub trait DbHandle: Send + Sync { @@ -42,7 +53,11 @@ pub trait DbHandle: Send + Sync { ptr: &JsonPointer, value: &Value, ) -> Result>, Error>; - async fn apply(&mut self, patch: DiffPatch) -> Result>, Error>; + async fn apply( + &mut self, + patch: DiffPatch, + store_write_lock: Option>, + ) -> Result>, Error>; async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) -> (); async fn get< T: for<'de> Deserialize<'de>, @@ -122,8 +137,12 @@ impl DbHandle for &mut Handle { ) -> Result>, Error> { (*self).put_value(ptr, value).await } - async fn apply(&mut self, patch: DiffPatch) -> Result>, Error> { - (*self).apply(patch).await + async fn apply( + &mut self, + patch: DiffPatch, + store_write_lock: Option>, + ) -> Result>, Error> { + (*self).apply(patch, store_write_lock).await } async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) { (*self).lock(ptr, lock_type).await @@ -230,8 +249,12 @@ impl DbHandle for PatchDbHandle { ) -> Result>, Error> { self.db.put(ptr, value, None).await } - async fn apply(&mut self, patch: DiffPatch) -> Result>, Error> { - self.db.apply(patch, None, None).await + async fn apply( + &mut self, + patch: DiffPatch, + store_write_lock: Option>, + ) -> Result>, Error> { + self.db.apply(patch, None, store_write_lock).await } async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) { self.locks diff --git a/patch-db/src/locker.rs b/patch-db/src/locker.rs index c52ec5e..4c0e983 100644 --- a/patch-db/src/locker.rs +++ b/patch-db/src/locker.rs @@ -21,35 +21,70 @@ impl Locker { // instead we want it to block forever by adding a channel that will never recv let (_dummy_send, dummy_recv) = oneshot::channel(); let mut locks_on_lease = vec![dummy_recv]; - while let Some(action) = get_action(&mut new_requests, &mut locks_on_lease).await { - #[cfg(feature = "log")] - log::trace!("Locker Action: {:#?}", action); + let (_dummy_send, dummy_recv) = oneshot::channel(); + let mut cancellations = vec![dummy_recv]; + while let Some(action) = + get_action(&mut new_requests, &mut locks_on_lease, &mut cancellations).await + { + #[cfg(feature = "tracing")] + tracing::trace!("Locker Action: {:#?}", action); match action { - Action::HandleRequest(req) => trie.handle_request(req, &mut locks_on_lease), + Action::HandleRequest(mut req) => { + cancellations.extend(req.cancel.take()); + trie.handle_request(req, &mut locks_on_lease) + } Action::HandleRelease(lock_info) => { trie.handle_release(lock_info, &mut locks_on_lease) } + Action::HandleCancel(lock_info) => { + trie.handle_cancel(lock_info, &mut locks_on_lease) + } } - #[cfg(feature = "log")] - log::trace!("Locker Trie: {:#?}", trie); + #[cfg(feature = "tracing")] + tracing::trace!("Locker Trie: {:#?}", trie); } }); Locker { sender } } pub async fn lock(&self, handle_id: HandleId, ptr: JsonPointer, lock_type: LockType) -> Guard { + struct CancelGuard { + lock_info: Option, + channel: Option>, + recv: oneshot::Receiver, + } + impl Drop for CancelGuard { + fn drop(&mut self) { + if let (Some(lock_info), Some(channel)) = + (self.lock_info.take(), self.channel.take()) + { + self.recv.close(); + let _ = channel.send(lock_info); + } + } + } + let lock_info = LockInfo { + handle_id, + ptr, + ty: lock_type, + segments_handled: 0, + }; let (send, recv) = oneshot::channel(); + let (cancel_send, cancel_recv) = oneshot::channel(); + let mut cancel_guard = CancelGuard { + lock_info: Some(lock_info.clone()), + channel: Some(cancel_send), + recv, + }; self.sender .send(Request { - lock_info: LockInfo { - handle_id, - ptr, - ty: lock_type, - segments_handled: 0, - }, + lock_info, + cancel: Some(cancel_recv), completion: send, }) .unwrap(); - recv.await.unwrap() + let res = (&mut cancel_guard.recv).await.unwrap(); + cancel_guard.channel.take(); + res } } @@ -62,13 +97,15 @@ struct RequestQueue { enum Action { HandleRequest(Request), HandleRelease(LockInfo), + HandleCancel(LockInfo), } async fn get_action( new_requests: &mut RequestQueue, locks_on_lease: &mut Vec>, + cancellations: &mut Vec>, ) -> Option { loop { - if new_requests.closed && locks_on_lease.is_empty() { + if new_requests.closed && locks_on_lease.len() == 1 && cancellations.len() == 1 { return None; } tokio::select! { @@ -83,6 +120,12 @@ async fn get_action( locks_on_lease.swap_remove(idx); return Some(Action::HandleRelease(a.unwrap())) } + (a, idx, _) = futures::future::select_all(cancellations.iter_mut()) => { + cancellations.swap_remove(idx); + if let Ok(a) = a { + return Some(Action::HandleCancel(a)) + } + } } } } @@ -123,6 +166,20 @@ impl Trie { .handle_release(release, locks_on_lease) } } + fn handle_cancel( + &mut self, + lock_info: LockInfo, + locks_on_lease: &mut Vec>, + ) { + let cancel = self.node.cancel(lock_info); + for req in std::mem::take(&mut self.node.reqs) { + self.handle_request(req, locks_on_lease); + } + if let Some(cancel) = cancel { + self.child_mut(cancel.current_seg()) + .handle_cancel(cancel, locks_on_lease) + } + } } #[derive(Debug, Default)] @@ -273,9 +330,26 @@ impl Node { Some(lock_info) } } + fn cancel(&mut self, mut lock_info: LockInfo) -> Option { + let mut idx = 0; + while idx < self.reqs.len() { + if self.reqs[idx].completion.is_closed() && self.reqs[idx].lock_info == lock_info { + self.reqs.swap_remove(idx); + return None; + } else { + idx += 1; + } + } + if lock_info.ptr.len() == lock_info.segments_handled { + None + } else { + lock_info.segments_handled += 1; + Some(lock_info) + } + } } -#[derive(Debug, Default)] +#[derive(Debug, Clone, Default, PartialEq)] struct LockInfo { ptr: JsonPointer, segments_handled: usize, @@ -313,6 +387,7 @@ impl Default for LockType { #[derive(Debug)] struct Request { lock_info: LockInfo, + cancel: Option>, completion: oneshot::Sender, } impl Request { @@ -328,10 +403,13 @@ impl Request { fn complete(self, locks_on_lease: &mut Vec>) { let (sender, receiver) = oneshot::channel(); locks_on_lease.push(receiver); - let _ = self.completion.send(Guard { + if let Err(_) = self.completion.send(Guard { lock_info: self.lock_info.reset(), sender: Some(sender), - }); + }) { + #[cfg(feature = "tracing")] + tracing::warn!("Completion sent to closed channel.") + } } } @@ -342,10 +420,14 @@ pub struct Guard { } impl Drop for Guard { fn drop(&mut self) { - let _ = self + if let Err(_e) = self .sender .take() .unwrap() - .send(std::mem::take(&mut self.lock_info)); + .send(std::mem::take(&mut self.lock_info)) + { + #[cfg(feature = "tracing")] + tracing::warn!("Failed to release lock: {:?}", _e) + } } } diff --git a/patch-db/src/model.rs b/patch-db/src/model.rs index d9eb1dc..1791005 100644 --- a/patch-db/src/model.rs +++ b/patch-db/src/model.rs @@ -33,13 +33,14 @@ pub struct ModelDataMut Deserialize<'de>> { ptr: JsonPointer, } impl Deserialize<'de>> ModelDataMut { - pub async fn save(self, db: &mut Db) -> Result<(), Error> { + pub async fn save(&mut self, db: &mut Db) -> Result<(), Error> { let current = serde_json::to_value(&self.current)?; let mut diff = crate::patch::diff(&self.original, ¤t); let target = db.get_value(&self.ptr, None).await?; diff.rebase(&crate::patch::diff(&self.original, &target)); diff.prepend(&self.ptr); - db.apply(diff).await?; + db.apply(diff, None).await?; + self.original = current; Ok(()) } } @@ -501,11 +502,12 @@ where pub async fn remove(&self, db: &mut Db, key: &T::Key) -> Result<(), Error> { db.lock(self.as_ref().clone(), LockType::Write).await; if db.exists(self.clone().idx(key).as_ref(), None).await? { - db.apply(DiffPatch(Patch(vec![PatchOperation::Remove( - RemoveOperation { + db.apply( + DiffPatch(Patch(vec![PatchOperation::Remove(RemoveOperation { path: self.as_ref().clone().join_end(key.as_ref()), - }, - )]))) + })])), + None, + ) .await?; } Ok(()) diff --git a/patch-db/src/store.rs b/patch-db/src/store.rs index dded388..42bd4ae 100644 --- a/patch-db/src/store.rs +++ b/patch-db/src/store.rs @@ -177,8 +177,8 @@ impl Store { return Ok(None); } - #[cfg(feature = "log")] - log::trace!("Attempting to apply patch: {:?}", patch); + #[cfg(feature = "tracing")] + tracing::trace!("Attempting to apply patch: {:?}", patch); self.check_cache_corrupted()?; let patch_bin = serde_cbor::to_vec(&*patch)?; @@ -288,10 +288,13 @@ impl PatchDb { } pub fn handle(&self) -> PatchDbHandle { PatchDbHandle { - id: HandleId( - self.handle_id + id: HandleId { + id: self + .handle_id .fetch_add(1, std::sync::atomic::Ordering::SeqCst), - ), + #[cfg(feature = "trace")] + trace: Some(Arc::new(tracing_error::SpanTrace::capture())), + }, db: self.clone(), locks: Vec::new(), } diff --git a/patch-db/src/transaction.rs b/patch-db/src/transaction.rs index 44045d0..24937b6 100644 --- a/patch-db/src/transaction.rs +++ b/patch-db/src/transaction.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use tokio::sync::broadcast::error::TryRecvError; use tokio::sync::broadcast::Receiver; -use tokio::sync::{RwLock, RwLockReadGuard}; +use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use crate::handle::HandleId; use crate::locker::{Guard, LockType, Locker}; @@ -51,10 +51,9 @@ impl Transaction<&mut PatchDbHandle> { impl Transaction { pub async fn save(mut self) -> Result<(), Error> { let store_lock = self.parent.store(); - let store = store_lock.read().await; + let store = store_lock.write().await; self.rebase()?; - self.parent.apply(self.updates).await?; - drop(store); + self.parent.apply(self.updates, Some(store)).await?; Ok(()) } } @@ -148,8 +147,8 @@ impl DbHandle for Transaction { }; let path_updates = self.updates.for_path(ptr); if !(path_updates.0).0.is_empty() { - #[cfg(feature = "log")] - log::trace!("applying patch {:?} at path {}", path_updates, ptr); + #[cfg(feature = "tracing")] + tracing::trace!("Applying patch {:?} at path {}", path_updates, ptr); json_patch::patch(&mut data, &*path_updates)?; } @@ -195,7 +194,11 @@ impl DbHandle for Transaction { ) -> Result>, Error> { self.put_value(ptr, &serde_json::to_value(value)?).await } - async fn apply(&mut self, patch: DiffPatch) -> Result>, Error> { + async fn apply( + &mut self, + patch: DiffPatch, + _store_write_lock: Option>, + ) -> Result>, Error> { self.updates.append(patch); Ok(None) }