cancel safety

This commit is contained in:
Aiden McClelland
2021-10-21 16:31:07 -06:00
parent 7bb573abb8
commit 16ba75225b
6 changed files with 162 additions and 47 deletions

View File

@@ -12,7 +12,8 @@ version = "0.1.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
debug = ["log"]
debug = ["tracing"]
trace = ["debug", "tracing-error"]
[dependencies]
async-trait = "0.1.42"
@@ -21,8 +22,9 @@ futures = "0.3.8"
json-patch = { path = "../json-patch" }
json-ptr = { path = "../json-ptr" }
lazy_static = "1.4.0"
log = { version = "*", optional = true }
nix = "0.22.1"
tracing = { version = "0.1.29", optional = true }
tracing-error = { version = "0.1.2", optional = true }
nix = "0.23.0"
patch-db-macro = { path = "../patch-db-macro" }
serde = { version = "1.0.118", features = ["rc"] }
serde_cbor = { path = "../cbor" }

View File

@@ -5,14 +5,25 @@ use json_ptr::{JsonPointer, SegList};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::BTreeSet;
use tokio::sync::RwLockWriteGuard;
use tokio::sync::{broadcast::Receiver, RwLock, RwLockReadGuard};
use crate::locker::LockType;
use crate::{locker::Guard, Locker, PatchDb, Revision, Store, Transaction};
use crate::{patch::DiffPatch, Error};
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct HandleId(pub(crate) u64);
#[derive(Debug, Clone, Default)]
pub struct HandleId {
pub(crate) id: u64,
#[cfg(feature = "trace")]
pub(crate) trace: Option<Arc<tracing_error::SpanTrace>>,
}
impl PartialEq for HandleId {
fn eq(&self, other: &Self) -> bool {
self.id == other.id
}
}
impl Eq for HandleId {}
#[async_trait]
pub trait DbHandle: Send + Sync {
@@ -42,7 +53,11 @@ pub trait DbHandle: Send + Sync {
ptr: &JsonPointer<S, V>,
value: &Value,
) -> Result<Option<Arc<Revision>>, Error>;
async fn apply(&mut self, patch: DiffPatch) -> Result<Option<Arc<Revision>>, Error>;
async fn apply(
&mut self,
patch: DiffPatch,
store_write_lock: Option<RwLockWriteGuard<'_, Store>>,
) -> Result<Option<Arc<Revision>>, Error>;
async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) -> ();
async fn get<
T: for<'de> Deserialize<'de>,
@@ -122,8 +137,12 @@ impl<Handle: DbHandle + ?Sized> DbHandle for &mut Handle {
) -> Result<Option<Arc<Revision>>, Error> {
(*self).put_value(ptr, value).await
}
async fn apply(&mut self, patch: DiffPatch) -> Result<Option<Arc<Revision>>, Error> {
(*self).apply(patch).await
async fn apply(
&mut self,
patch: DiffPatch,
store_write_lock: Option<RwLockWriteGuard<'_, Store>>,
) -> Result<Option<Arc<Revision>>, Error> {
(*self).apply(patch, store_write_lock).await
}
async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) {
(*self).lock(ptr, lock_type).await
@@ -230,8 +249,12 @@ impl DbHandle for PatchDbHandle {
) -> Result<Option<Arc<Revision>>, Error> {
self.db.put(ptr, value, None).await
}
async fn apply(&mut self, patch: DiffPatch) -> Result<Option<Arc<Revision>>, Error> {
self.db.apply(patch, None, None).await
async fn apply(
&mut self,
patch: DiffPatch,
store_write_lock: Option<RwLockWriteGuard<'_, Store>>,
) -> Result<Option<Arc<Revision>>, Error> {
self.db.apply(patch, None, store_write_lock).await
}
async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) {
self.locks

View File

@@ -21,35 +21,70 @@ impl Locker {
// instead we want it to block forever by adding a channel that will never recv
let (_dummy_send, dummy_recv) = oneshot::channel();
let mut locks_on_lease = vec![dummy_recv];
while let Some(action) = get_action(&mut new_requests, &mut locks_on_lease).await {
#[cfg(feature = "log")]
log::trace!("Locker Action: {:#?}", action);
let (_dummy_send, dummy_recv) = oneshot::channel();
let mut cancellations = vec![dummy_recv];
while let Some(action) =
get_action(&mut new_requests, &mut locks_on_lease, &mut cancellations).await
{
#[cfg(feature = "tracing")]
tracing::trace!("Locker Action: {:#?}", action);
match action {
Action::HandleRequest(req) => trie.handle_request(req, &mut locks_on_lease),
Action::HandleRequest(mut req) => {
cancellations.extend(req.cancel.take());
trie.handle_request(req, &mut locks_on_lease)
}
Action::HandleRelease(lock_info) => {
trie.handle_release(lock_info, &mut locks_on_lease)
}
Action::HandleCancel(lock_info) => {
trie.handle_cancel(lock_info, &mut locks_on_lease)
}
}
#[cfg(feature = "log")]
log::trace!("Locker Trie: {:#?}", trie);
#[cfg(feature = "tracing")]
tracing::trace!("Locker Trie: {:#?}", trie);
}
});
Locker { sender }
}
pub async fn lock(&self, handle_id: HandleId, ptr: JsonPointer, lock_type: LockType) -> Guard {
struct CancelGuard {
lock_info: Option<LockInfo>,
channel: Option<oneshot::Sender<LockInfo>>,
recv: oneshot::Receiver<Guard>,
}
impl Drop for CancelGuard {
fn drop(&mut self) {
if let (Some(lock_info), Some(channel)) =
(self.lock_info.take(), self.channel.take())
{
self.recv.close();
let _ = channel.send(lock_info);
}
}
}
let lock_info = LockInfo {
handle_id,
ptr,
ty: lock_type,
segments_handled: 0,
};
let (send, recv) = oneshot::channel();
let (cancel_send, cancel_recv) = oneshot::channel();
let mut cancel_guard = CancelGuard {
lock_info: Some(lock_info.clone()),
channel: Some(cancel_send),
recv,
};
self.sender
.send(Request {
lock_info: LockInfo {
handle_id,
ptr,
ty: lock_type,
segments_handled: 0,
},
lock_info,
cancel: Some(cancel_recv),
completion: send,
})
.unwrap();
recv.await.unwrap()
let res = (&mut cancel_guard.recv).await.unwrap();
cancel_guard.channel.take();
res
}
}
@@ -62,13 +97,15 @@ struct RequestQueue {
enum Action {
HandleRequest(Request),
HandleRelease(LockInfo),
HandleCancel(LockInfo),
}
async fn get_action(
new_requests: &mut RequestQueue,
locks_on_lease: &mut Vec<oneshot::Receiver<LockInfo>>,
cancellations: &mut Vec<oneshot::Receiver<LockInfo>>,
) -> Option<Action> {
loop {
if new_requests.closed && locks_on_lease.is_empty() {
if new_requests.closed && locks_on_lease.len() == 1 && cancellations.len() == 1 {
return None;
}
tokio::select! {
@@ -83,6 +120,12 @@ async fn get_action(
locks_on_lease.swap_remove(idx);
return Some(Action::HandleRelease(a.unwrap()))
}
(a, idx, _) = futures::future::select_all(cancellations.iter_mut()) => {
cancellations.swap_remove(idx);
if let Ok(a) = a {
return Some(Action::HandleCancel(a))
}
}
}
}
}
@@ -123,6 +166,20 @@ impl Trie {
.handle_release(release, locks_on_lease)
}
}
fn handle_cancel(
&mut self,
lock_info: LockInfo,
locks_on_lease: &mut Vec<oneshot::Receiver<LockInfo>>,
) {
let cancel = self.node.cancel(lock_info);
for req in std::mem::take(&mut self.node.reqs) {
self.handle_request(req, locks_on_lease);
}
if let Some(cancel) = cancel {
self.child_mut(cancel.current_seg())
.handle_cancel(cancel, locks_on_lease)
}
}
}
#[derive(Debug, Default)]
@@ -273,9 +330,26 @@ impl Node {
Some(lock_info)
}
}
fn cancel(&mut self, mut lock_info: LockInfo) -> Option<LockInfo> {
let mut idx = 0;
while idx < self.reqs.len() {
if self.reqs[idx].completion.is_closed() && self.reqs[idx].lock_info == lock_info {
self.reqs.swap_remove(idx);
return None;
} else {
idx += 1;
}
}
if lock_info.ptr.len() == lock_info.segments_handled {
None
} else {
lock_info.segments_handled += 1;
Some(lock_info)
}
}
}
#[derive(Debug, Default)]
#[derive(Debug, Clone, Default, PartialEq)]
struct LockInfo {
ptr: JsonPointer,
segments_handled: usize,
@@ -313,6 +387,7 @@ impl Default for LockType {
#[derive(Debug)]
struct Request {
lock_info: LockInfo,
cancel: Option<oneshot::Receiver<LockInfo>>,
completion: oneshot::Sender<Guard>,
}
impl Request {
@@ -328,10 +403,13 @@ impl Request {
fn complete(self, locks_on_lease: &mut Vec<oneshot::Receiver<LockInfo>>) {
let (sender, receiver) = oneshot::channel();
locks_on_lease.push(receiver);
let _ = self.completion.send(Guard {
if let Err(_) = self.completion.send(Guard {
lock_info: self.lock_info.reset(),
sender: Some(sender),
});
}) {
#[cfg(feature = "tracing")]
tracing::warn!("Completion sent to closed channel.")
}
}
}
@@ -342,10 +420,14 @@ pub struct Guard {
}
impl Drop for Guard {
fn drop(&mut self) {
let _ = self
if let Err(_e) = self
.sender
.take()
.unwrap()
.send(std::mem::take(&mut self.lock_info));
.send(std::mem::take(&mut self.lock_info))
{
#[cfg(feature = "tracing")]
tracing::warn!("Failed to release lock: {:?}", _e)
}
}
}

View File

@@ -33,13 +33,14 @@ pub struct ModelDataMut<T: Serialize + for<'de> Deserialize<'de>> {
ptr: JsonPointer,
}
impl<T: Serialize + for<'de> Deserialize<'de>> ModelDataMut<T> {
pub async fn save<Db: DbHandle>(self, db: &mut Db) -> Result<(), Error> {
pub async fn save<Db: DbHandle>(&mut self, db: &mut Db) -> Result<(), Error> {
let current = serde_json::to_value(&self.current)?;
let mut diff = crate::patch::diff(&self.original, &current);
let target = db.get_value(&self.ptr, None).await?;
diff.rebase(&crate::patch::diff(&self.original, &target));
diff.prepend(&self.ptr);
db.apply(diff).await?;
db.apply(diff, None).await?;
self.original = current;
Ok(())
}
}
@@ -501,11 +502,12 @@ where
pub async fn remove<Db: DbHandle>(&self, db: &mut Db, key: &T::Key) -> Result<(), Error> {
db.lock(self.as_ref().clone(), LockType::Write).await;
if db.exists(self.clone().idx(key).as_ref(), None).await? {
db.apply(DiffPatch(Patch(vec![PatchOperation::Remove(
RemoveOperation {
db.apply(
DiffPatch(Patch(vec![PatchOperation::Remove(RemoveOperation {
path: self.as_ref().clone().join_end(key.as_ref()),
},
)])))
})])),
None,
)
.await?;
}
Ok(())

View File

@@ -177,8 +177,8 @@ impl Store {
return Ok(None);
}
#[cfg(feature = "log")]
log::trace!("Attempting to apply patch: {:?}", patch);
#[cfg(feature = "tracing")]
tracing::trace!("Attempting to apply patch: {:?}", patch);
self.check_cache_corrupted()?;
let patch_bin = serde_cbor::to_vec(&*patch)?;
@@ -288,10 +288,13 @@ impl PatchDb {
}
pub fn handle(&self) -> PatchDbHandle {
PatchDbHandle {
id: HandleId(
self.handle_id
id: HandleId {
id: self
.handle_id
.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
),
#[cfg(feature = "trace")]
trace: Some(Arc::new(tracing_error::SpanTrace::capture())),
},
db: self.clone(),
locks: Vec::new(),
}

View File

@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::broadcast::error::TryRecvError;
use tokio::sync::broadcast::Receiver;
use tokio::sync::{RwLock, RwLockReadGuard};
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use crate::handle::HandleId;
use crate::locker::{Guard, LockType, Locker};
@@ -51,10 +51,9 @@ impl Transaction<&mut PatchDbHandle> {
impl<Parent: DbHandle + Send + Sync> Transaction<Parent> {
pub async fn save(mut self) -> Result<(), Error> {
let store_lock = self.parent.store();
let store = store_lock.read().await;
let store = store_lock.write().await;
self.rebase()?;
self.parent.apply(self.updates).await?;
drop(store);
self.parent.apply(self.updates, Some(store)).await?;
Ok(())
}
}
@@ -148,8 +147,8 @@ impl<Parent: DbHandle + Send + Sync> DbHandle for Transaction<Parent> {
};
let path_updates = self.updates.for_path(ptr);
if !(path_updates.0).0.is_empty() {
#[cfg(feature = "log")]
log::trace!("applying patch {:?} at path {}", path_updates, ptr);
#[cfg(feature = "tracing")]
tracing::trace!("Applying patch {:?} at path {}", path_updates, ptr);
json_patch::patch(&mut data, &*path_updates)?;
}
@@ -195,7 +194,11 @@ impl<Parent: DbHandle + Send + Sync> DbHandle for Transaction<Parent> {
) -> Result<Option<Arc<Revision>>, Error> {
self.put_value(ptr, &serde_json::to_value(value)?).await
}
async fn apply(&mut self, patch: DiffPatch) -> Result<Option<Arc<Revision>>, Error> {
async fn apply(
&mut self,
patch: DiffPatch,
_store_write_lock: Option<RwLockWriteGuard<'_, Store>>,
) -> Result<Option<Arc<Revision>>, Error> {
self.updates.append(patch);
Ok(None)
}