mirror of
https://github.com/Start9Labs/patch-db.git
synced 2026-03-26 02:11:54 +00:00
cancel safety
This commit is contained in:
@@ -12,7 +12,8 @@ version = "0.1.0"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
[features]
|
||||
debug = ["log"]
|
||||
debug = ["tracing"]
|
||||
trace = ["debug", "tracing-error"]
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.42"
|
||||
@@ -21,8 +22,9 @@ futures = "0.3.8"
|
||||
json-patch = { path = "../json-patch" }
|
||||
json-ptr = { path = "../json-ptr" }
|
||||
lazy_static = "1.4.0"
|
||||
log = { version = "*", optional = true }
|
||||
nix = "0.22.1"
|
||||
tracing = { version = "0.1.29", optional = true }
|
||||
tracing-error = { version = "0.1.2", optional = true }
|
||||
nix = "0.23.0"
|
||||
patch-db-macro = { path = "../patch-db-macro" }
|
||||
serde = { version = "1.0.118", features = ["rc"] }
|
||||
serde_cbor = { path = "../cbor" }
|
||||
|
||||
@@ -5,14 +5,25 @@ use json_ptr::{JsonPointer, SegList};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::BTreeSet;
|
||||
use tokio::sync::RwLockWriteGuard;
|
||||
use tokio::sync::{broadcast::Receiver, RwLock, RwLockReadGuard};
|
||||
|
||||
use crate::locker::LockType;
|
||||
use crate::{locker::Guard, Locker, PatchDb, Revision, Store, Transaction};
|
||||
use crate::{patch::DiffPatch, Error};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub struct HandleId(pub(crate) u64);
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct HandleId {
|
||||
pub(crate) id: u64,
|
||||
#[cfg(feature = "trace")]
|
||||
pub(crate) trace: Option<Arc<tracing_error::SpanTrace>>,
|
||||
}
|
||||
impl PartialEq for HandleId {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.id == other.id
|
||||
}
|
||||
}
|
||||
impl Eq for HandleId {}
|
||||
|
||||
#[async_trait]
|
||||
pub trait DbHandle: Send + Sync {
|
||||
@@ -42,7 +53,11 @@ pub trait DbHandle: Send + Sync {
|
||||
ptr: &JsonPointer<S, V>,
|
||||
value: &Value,
|
||||
) -> Result<Option<Arc<Revision>>, Error>;
|
||||
async fn apply(&mut self, patch: DiffPatch) -> Result<Option<Arc<Revision>>, Error>;
|
||||
async fn apply(
|
||||
&mut self,
|
||||
patch: DiffPatch,
|
||||
store_write_lock: Option<RwLockWriteGuard<'_, Store>>,
|
||||
) -> Result<Option<Arc<Revision>>, Error>;
|
||||
async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) -> ();
|
||||
async fn get<
|
||||
T: for<'de> Deserialize<'de>,
|
||||
@@ -122,8 +137,12 @@ impl<Handle: DbHandle + ?Sized> DbHandle for &mut Handle {
|
||||
) -> Result<Option<Arc<Revision>>, Error> {
|
||||
(*self).put_value(ptr, value).await
|
||||
}
|
||||
async fn apply(&mut self, patch: DiffPatch) -> Result<Option<Arc<Revision>>, Error> {
|
||||
(*self).apply(patch).await
|
||||
async fn apply(
|
||||
&mut self,
|
||||
patch: DiffPatch,
|
||||
store_write_lock: Option<RwLockWriteGuard<'_, Store>>,
|
||||
) -> Result<Option<Arc<Revision>>, Error> {
|
||||
(*self).apply(patch, store_write_lock).await
|
||||
}
|
||||
async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) {
|
||||
(*self).lock(ptr, lock_type).await
|
||||
@@ -230,8 +249,12 @@ impl DbHandle for PatchDbHandle {
|
||||
) -> Result<Option<Arc<Revision>>, Error> {
|
||||
self.db.put(ptr, value, None).await
|
||||
}
|
||||
async fn apply(&mut self, patch: DiffPatch) -> Result<Option<Arc<Revision>>, Error> {
|
||||
self.db.apply(patch, None, None).await
|
||||
async fn apply(
|
||||
&mut self,
|
||||
patch: DiffPatch,
|
||||
store_write_lock: Option<RwLockWriteGuard<'_, Store>>,
|
||||
) -> Result<Option<Arc<Revision>>, Error> {
|
||||
self.db.apply(patch, None, store_write_lock).await
|
||||
}
|
||||
async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) {
|
||||
self.locks
|
||||
|
||||
@@ -21,35 +21,70 @@ impl Locker {
|
||||
// instead we want it to block forever by adding a channel that will never recv
|
||||
let (_dummy_send, dummy_recv) = oneshot::channel();
|
||||
let mut locks_on_lease = vec![dummy_recv];
|
||||
while let Some(action) = get_action(&mut new_requests, &mut locks_on_lease).await {
|
||||
#[cfg(feature = "log")]
|
||||
log::trace!("Locker Action: {:#?}", action);
|
||||
let (_dummy_send, dummy_recv) = oneshot::channel();
|
||||
let mut cancellations = vec![dummy_recv];
|
||||
while let Some(action) =
|
||||
get_action(&mut new_requests, &mut locks_on_lease, &mut cancellations).await
|
||||
{
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("Locker Action: {:#?}", action);
|
||||
match action {
|
||||
Action::HandleRequest(req) => trie.handle_request(req, &mut locks_on_lease),
|
||||
Action::HandleRequest(mut req) => {
|
||||
cancellations.extend(req.cancel.take());
|
||||
trie.handle_request(req, &mut locks_on_lease)
|
||||
}
|
||||
Action::HandleRelease(lock_info) => {
|
||||
trie.handle_release(lock_info, &mut locks_on_lease)
|
||||
}
|
||||
Action::HandleCancel(lock_info) => {
|
||||
trie.handle_cancel(lock_info, &mut locks_on_lease)
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "log")]
|
||||
log::trace!("Locker Trie: {:#?}", trie);
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("Locker Trie: {:#?}", trie);
|
||||
}
|
||||
});
|
||||
Locker { sender }
|
||||
}
|
||||
pub async fn lock(&self, handle_id: HandleId, ptr: JsonPointer, lock_type: LockType) -> Guard {
|
||||
struct CancelGuard {
|
||||
lock_info: Option<LockInfo>,
|
||||
channel: Option<oneshot::Sender<LockInfo>>,
|
||||
recv: oneshot::Receiver<Guard>,
|
||||
}
|
||||
impl Drop for CancelGuard {
|
||||
fn drop(&mut self) {
|
||||
if let (Some(lock_info), Some(channel)) =
|
||||
(self.lock_info.take(), self.channel.take())
|
||||
{
|
||||
self.recv.close();
|
||||
let _ = channel.send(lock_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
let lock_info = LockInfo {
|
||||
handle_id,
|
||||
ptr,
|
||||
ty: lock_type,
|
||||
segments_handled: 0,
|
||||
};
|
||||
let (send, recv) = oneshot::channel();
|
||||
let (cancel_send, cancel_recv) = oneshot::channel();
|
||||
let mut cancel_guard = CancelGuard {
|
||||
lock_info: Some(lock_info.clone()),
|
||||
channel: Some(cancel_send),
|
||||
recv,
|
||||
};
|
||||
self.sender
|
||||
.send(Request {
|
||||
lock_info: LockInfo {
|
||||
handle_id,
|
||||
ptr,
|
||||
ty: lock_type,
|
||||
segments_handled: 0,
|
||||
},
|
||||
lock_info,
|
||||
cancel: Some(cancel_recv),
|
||||
completion: send,
|
||||
})
|
||||
.unwrap();
|
||||
recv.await.unwrap()
|
||||
let res = (&mut cancel_guard.recv).await.unwrap();
|
||||
cancel_guard.channel.take();
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,13 +97,15 @@ struct RequestQueue {
|
||||
enum Action {
|
||||
HandleRequest(Request),
|
||||
HandleRelease(LockInfo),
|
||||
HandleCancel(LockInfo),
|
||||
}
|
||||
async fn get_action(
|
||||
new_requests: &mut RequestQueue,
|
||||
locks_on_lease: &mut Vec<oneshot::Receiver<LockInfo>>,
|
||||
cancellations: &mut Vec<oneshot::Receiver<LockInfo>>,
|
||||
) -> Option<Action> {
|
||||
loop {
|
||||
if new_requests.closed && locks_on_lease.is_empty() {
|
||||
if new_requests.closed && locks_on_lease.len() == 1 && cancellations.len() == 1 {
|
||||
return None;
|
||||
}
|
||||
tokio::select! {
|
||||
@@ -83,6 +120,12 @@ async fn get_action(
|
||||
locks_on_lease.swap_remove(idx);
|
||||
return Some(Action::HandleRelease(a.unwrap()))
|
||||
}
|
||||
(a, idx, _) = futures::future::select_all(cancellations.iter_mut()) => {
|
||||
cancellations.swap_remove(idx);
|
||||
if let Ok(a) = a {
|
||||
return Some(Action::HandleCancel(a))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -123,6 +166,20 @@ impl Trie {
|
||||
.handle_release(release, locks_on_lease)
|
||||
}
|
||||
}
|
||||
fn handle_cancel(
|
||||
&mut self,
|
||||
lock_info: LockInfo,
|
||||
locks_on_lease: &mut Vec<oneshot::Receiver<LockInfo>>,
|
||||
) {
|
||||
let cancel = self.node.cancel(lock_info);
|
||||
for req in std::mem::take(&mut self.node.reqs) {
|
||||
self.handle_request(req, locks_on_lease);
|
||||
}
|
||||
if let Some(cancel) = cancel {
|
||||
self.child_mut(cancel.current_seg())
|
||||
.handle_cancel(cancel, locks_on_lease)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
@@ -273,9 +330,26 @@ impl Node {
|
||||
Some(lock_info)
|
||||
}
|
||||
}
|
||||
fn cancel(&mut self, mut lock_info: LockInfo) -> Option<LockInfo> {
|
||||
let mut idx = 0;
|
||||
while idx < self.reqs.len() {
|
||||
if self.reqs[idx].completion.is_closed() && self.reqs[idx].lock_info == lock_info {
|
||||
self.reqs.swap_remove(idx);
|
||||
return None;
|
||||
} else {
|
||||
idx += 1;
|
||||
}
|
||||
}
|
||||
if lock_info.ptr.len() == lock_info.segments_handled {
|
||||
None
|
||||
} else {
|
||||
lock_info.segments_handled += 1;
|
||||
Some(lock_info)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
#[derive(Debug, Clone, Default, PartialEq)]
|
||||
struct LockInfo {
|
||||
ptr: JsonPointer,
|
||||
segments_handled: usize,
|
||||
@@ -313,6 +387,7 @@ impl Default for LockType {
|
||||
#[derive(Debug)]
|
||||
struct Request {
|
||||
lock_info: LockInfo,
|
||||
cancel: Option<oneshot::Receiver<LockInfo>>,
|
||||
completion: oneshot::Sender<Guard>,
|
||||
}
|
||||
impl Request {
|
||||
@@ -328,10 +403,13 @@ impl Request {
|
||||
fn complete(self, locks_on_lease: &mut Vec<oneshot::Receiver<LockInfo>>) {
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
locks_on_lease.push(receiver);
|
||||
let _ = self.completion.send(Guard {
|
||||
if let Err(_) = self.completion.send(Guard {
|
||||
lock_info: self.lock_info.reset(),
|
||||
sender: Some(sender),
|
||||
});
|
||||
}) {
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::warn!("Completion sent to closed channel.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -342,10 +420,14 @@ pub struct Guard {
|
||||
}
|
||||
impl Drop for Guard {
|
||||
fn drop(&mut self) {
|
||||
let _ = self
|
||||
if let Err(_e) = self
|
||||
.sender
|
||||
.take()
|
||||
.unwrap()
|
||||
.send(std::mem::take(&mut self.lock_info));
|
||||
.send(std::mem::take(&mut self.lock_info))
|
||||
{
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::warn!("Failed to release lock: {:?}", _e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,13 +33,14 @@ pub struct ModelDataMut<T: Serialize + for<'de> Deserialize<'de>> {
|
||||
ptr: JsonPointer,
|
||||
}
|
||||
impl<T: Serialize + for<'de> Deserialize<'de>> ModelDataMut<T> {
|
||||
pub async fn save<Db: DbHandle>(self, db: &mut Db) -> Result<(), Error> {
|
||||
pub async fn save<Db: DbHandle>(&mut self, db: &mut Db) -> Result<(), Error> {
|
||||
let current = serde_json::to_value(&self.current)?;
|
||||
let mut diff = crate::patch::diff(&self.original, ¤t);
|
||||
let target = db.get_value(&self.ptr, None).await?;
|
||||
diff.rebase(&crate::patch::diff(&self.original, &target));
|
||||
diff.prepend(&self.ptr);
|
||||
db.apply(diff).await?;
|
||||
db.apply(diff, None).await?;
|
||||
self.original = current;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -501,11 +502,12 @@ where
|
||||
pub async fn remove<Db: DbHandle>(&self, db: &mut Db, key: &T::Key) -> Result<(), Error> {
|
||||
db.lock(self.as_ref().clone(), LockType::Write).await;
|
||||
if db.exists(self.clone().idx(key).as_ref(), None).await? {
|
||||
db.apply(DiffPatch(Patch(vec![PatchOperation::Remove(
|
||||
RemoveOperation {
|
||||
db.apply(
|
||||
DiffPatch(Patch(vec![PatchOperation::Remove(RemoveOperation {
|
||||
path: self.as_ref().clone().join_end(key.as_ref()),
|
||||
},
|
||||
)])))
|
||||
})])),
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Ok(())
|
||||
|
||||
@@ -177,8 +177,8 @@ impl Store {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
#[cfg(feature = "log")]
|
||||
log::trace!("Attempting to apply patch: {:?}", patch);
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("Attempting to apply patch: {:?}", patch);
|
||||
|
||||
self.check_cache_corrupted()?;
|
||||
let patch_bin = serde_cbor::to_vec(&*patch)?;
|
||||
@@ -288,10 +288,13 @@ impl PatchDb {
|
||||
}
|
||||
pub fn handle(&self) -> PatchDbHandle {
|
||||
PatchDbHandle {
|
||||
id: HandleId(
|
||||
self.handle_id
|
||||
id: HandleId {
|
||||
id: self
|
||||
.handle_id
|
||||
.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
|
||||
),
|
||||
#[cfg(feature = "trace")]
|
||||
trace: Some(Arc::new(tracing_error::SpanTrace::capture())),
|
||||
},
|
||||
db: self.clone(),
|
||||
locks: Vec::new(),
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::broadcast::error::TryRecvError;
|
||||
use tokio::sync::broadcast::Receiver;
|
||||
use tokio::sync::{RwLock, RwLockReadGuard};
|
||||
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
|
||||
|
||||
use crate::handle::HandleId;
|
||||
use crate::locker::{Guard, LockType, Locker};
|
||||
@@ -51,10 +51,9 @@ impl Transaction<&mut PatchDbHandle> {
|
||||
impl<Parent: DbHandle + Send + Sync> Transaction<Parent> {
|
||||
pub async fn save(mut self) -> Result<(), Error> {
|
||||
let store_lock = self.parent.store();
|
||||
let store = store_lock.read().await;
|
||||
let store = store_lock.write().await;
|
||||
self.rebase()?;
|
||||
self.parent.apply(self.updates).await?;
|
||||
drop(store);
|
||||
self.parent.apply(self.updates, Some(store)).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -148,8 +147,8 @@ impl<Parent: DbHandle + Send + Sync> DbHandle for Transaction<Parent> {
|
||||
};
|
||||
let path_updates = self.updates.for_path(ptr);
|
||||
if !(path_updates.0).0.is_empty() {
|
||||
#[cfg(feature = "log")]
|
||||
log::trace!("applying patch {:?} at path {}", path_updates, ptr);
|
||||
#[cfg(feature = "tracing")]
|
||||
tracing::trace!("Applying patch {:?} at path {}", path_updates, ptr);
|
||||
|
||||
json_patch::patch(&mut data, &*path_updates)?;
|
||||
}
|
||||
@@ -195,7 +194,11 @@ impl<Parent: DbHandle + Send + Sync> DbHandle for Transaction<Parent> {
|
||||
) -> Result<Option<Arc<Revision>>, Error> {
|
||||
self.put_value(ptr, &serde_json::to_value(value)?).await
|
||||
}
|
||||
async fn apply(&mut self, patch: DiffPatch) -> Result<Option<Arc<Revision>>, Error> {
|
||||
async fn apply(
|
||||
&mut self,
|
||||
patch: DiffPatch,
|
||||
_store_write_lock: Option<RwLockWriteGuard<'_, Store>>,
|
||||
) -> Result<Option<Arc<Revision>>, Error> {
|
||||
self.updates.append(patch);
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user