diff --git a/patch-db/src/handle.rs b/patch-db/src/handle.rs index 7630b01..88e627c 100644 --- a/patch-db/src/handle.rs +++ b/patch-db/src/handle.rs @@ -1,16 +1,16 @@ +use std::collections::BTreeSet; use std::sync::Arc; use async_trait::async_trait; use json_ptr::{JsonPointer, SegList}; use serde::{Deserialize, Serialize}; use serde_json::Value; -use std::collections::BTreeSet; -use tokio::sync::RwLockWriteGuard; -use tokio::sync::{broadcast::Receiver, RwLock, RwLockReadGuard}; +use tokio::sync::broadcast::Receiver; +use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; -use crate::locker::LockType; -use crate::{locker::Guard, Locker, PatchDb, Revision, Store, Transaction}; -use crate::{patch::DiffPatch, Error}; +use crate::locker::{Guard, LockError, LockType}; +use crate::patch::DiffPatch; +use crate::{Error, Locker, PatchDb, Revision, Store, Transaction}; #[derive(Debug, Clone, Default)] pub struct HandleId { @@ -68,7 +68,7 @@ pub trait DbHandle: Send + Sync { patch: DiffPatch, store_write_lock: Option>, ) -> Result>, Error>; - async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) -> (); + async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) -> Result<(), LockError>; async fn get< T: for<'de> Deserialize<'de>, S: AsRef + Send + Sync, @@ -154,7 +154,7 @@ impl DbHandle for &mut Handle { ) -> Result>, Error> { (*self).apply(patch, store_write_lock).await } - async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) { + async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) -> Result<(), LockError> { (*self).lock(ptr, lock_type).await } async fn get< @@ -266,9 +266,10 @@ impl DbHandle for PatchDbHandle { ) -> Result>, Error> { self.db.apply(patch, None, store_write_lock).await } - async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) { - self.locks - .push(self.db.locker.lock(self.id.clone(), ptr, lock_type).await); + async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) -> Result<(), LockError> { + Ok(self + .locks + .push(self.db.locker.lock(self.id.clone(), ptr, lock_type).await?)) } async fn get< T: for<'de> Deserialize<'de>, diff --git a/patch-db/src/lib.rs b/patch-db/src/lib.rs index 45ced01..abd2958 100644 --- a/patch-db/src/lib.rs +++ b/patch-db/src/lib.rs @@ -2,6 +2,7 @@ use std::io::Error as IOError; use std::sync::Arc; use json_ptr::JsonPointer; +use locker::LockError; use thiserror::Error; use tokio::sync::broadcast::error::TryRecvError; @@ -20,8 +21,6 @@ mod proptest; mod test; pub use handle::{DbHandle, PatchDbHandle}; -pub use json_patch; -pub use json_ptr; pub use locker::{LockType, Locker}; pub use model::{ BoxModel, HasModel, Map, MapModel, Model, ModelData, ModelDataMut, OptionModel, VecModel, @@ -30,6 +29,7 @@ pub use patch::{DiffPatch, Dump, Revision}; pub use patch_db_macro::HasModel; pub use store::{PatchDb, Store}; pub use transaction::Transaction; +pub use {json_patch, json_ptr}; #[derive(Error, Debug)] pub enum Error { @@ -53,4 +53,6 @@ pub enum Error { Subscriber(#[from] TryRecvError), #[error("Node Does Not Exist: {0}")] NodeDoesNotExist(JsonPointer), + #[error("Invalid Lock Request: {0}")] + LockError(#[from] LockError), } diff --git a/patch-db/src/locker.rs b/patch-db/src/locker.rs index d63ac9d..7b61d3d 100644 --- a/patch-db/src/locker.rs +++ b/patch-db/src/locker.rs @@ -209,7 +209,12 @@ impl Locker { }); Locker { sender } } - pub async fn lock(&self, handle_id: HandleId, ptr: JsonPointer, lock_type: LockType) -> Guard { + pub async fn lock( + &self, + handle_id: HandleId, + ptr: JsonPointer, + lock_type: LockType, + ) -> Result { struct CancelGuard { lock_info: Option, channel: Option>, @@ -246,7 +251,7 @@ impl Locker { .unwrap(); let res = (&mut cancel_guard.recv).await.unwrap(); cancel_guard.channel.take(); - res + Ok(res) } } @@ -910,6 +915,31 @@ impl std::fmt::Display for LockType { } } +#[derive(Debug, Clone, thiserror::Error)] +pub enum LockError { + #[error("Lock Taxonomy Escalation: Session = {session:?}, First = {first}, Second = {second}")] + LockTaxonomyEscalation { + session: HandleId, + first: JsonPointer, + second: JsonPointer, + }, + #[error("Lock Type Escalation: Session = {session:?}, Pointer = {ptr}, First = {first}, Second = {second}")] + LockTypeEscalation { + session: HandleId, + ptr: JsonPointer, + first: LockType, + second: LockType, + }, + #[error( + "Non-Canonical Lock Ordering: Session = {session:?}, First = {first}, Second = {second}" + )] + NonCanonicalOrdering { + session: HandleId, + first: JsonPointer, + second: JsonPointer, + }, +} + #[derive(Debug)] struct Request { lock_info: LockInfo, diff --git a/patch-db/src/model.rs b/patch-db/src/model.rs index eb424a8..bbc9bc0 100644 --- a/patch-db/src/model.rs +++ b/patch-db/src/model.rs @@ -9,7 +9,7 @@ use json_ptr::JsonPointer; use serde::{Deserialize, Serialize}; use serde_json::Value; -use crate::locker::LockType; +use crate::locker::{LockError, LockType}; use crate::{DbHandle, DiffPatch, Error, Revision}; #[derive(Debug)] @@ -65,19 +65,23 @@ impl Model where T: Serialize + for<'de> Deserialize<'de>, { - pub async fn lock(&self, db: &mut Db, lock_type: LockType) { + pub async fn lock( + &self, + db: &mut Db, + lock_type: LockType, + ) -> Result<(), LockError> { db.lock(self.ptr.clone(), lock_type).await } pub async fn get(&self, db: &mut Db, lock: bool) -> Result, Error> { if lock { - self.lock(db, LockType::Read).await; + self.lock(db, LockType::Read).await?; } Ok(ModelData(db.get(&self.ptr).await?)) } pub async fn get_mut(&self, db: &mut Db) -> Result, Error> { - self.lock(db, LockType::Write).await; + self.lock(db, LockType::Write).await?; let original = db.get_value(&self.ptr, None).await?; let current = serde_json::from_value(original.clone())?; Ok(ModelDataMut { @@ -105,7 +109,7 @@ where db: &mut Db, value: &T, ) -> Result>, Error> { - self.lock(db, LockType::Write).await; + self.lock(db, LockType::Write).await?; db.put(&self.ptr, value).await } } @@ -228,7 +232,11 @@ impl Deserialize<'de>> HasModel for Box { #[derive(Debug)] pub struct OptionModel Deserialize<'de>>(T::Model); impl Deserialize<'de>> OptionModel { - pub async fn lock(&self, db: &mut Db, lock_type: LockType) { + pub async fn lock( + &self, + db: &mut Db, + lock_type: LockType, + ) -> Result<(), LockError> { db.lock(self.0.as_ref().clone(), lock_type).await } @@ -238,7 +246,7 @@ impl Deserialize<'de>> OptionModel { lock: bool, ) -> Result>, Error> { if lock { - self.lock(db, LockType::Read).await; + self.lock(db, LockType::Read).await?; } Ok(ModelData(db.get(self.0.as_ref()).await?)) } @@ -247,7 +255,7 @@ impl Deserialize<'de>> OptionModel { &self, db: &mut Db, ) -> Result>, Error> { - self.lock(db, LockType::Write).await; + self.lock(db, LockType::Write).await?; let original = db.get_value(self.0.as_ref(), None).await?; let current = serde_json::from_value(original.clone())?; Ok(ModelDataMut { @@ -259,7 +267,7 @@ impl Deserialize<'de>> OptionModel { pub async fn exists(&self, db: &mut Db, lock: bool) -> Result { if lock { - db.lock(self.0.as_ref().clone(), LockType::Exist).await; + db.lock(self.0.as_ref().clone(), LockType::Exist).await?; } Ok(db.exists(&self.as_ref(), None).await?) } @@ -285,7 +293,7 @@ impl Deserialize<'de>> OptionModel { } pub async fn delete(&self, db: &mut Db) -> Result>, Error> { - db.lock(self.as_ref().clone(), LockType::Write).await; + db.lock(self.as_ref().clone(), LockType::Write).await?; db.put(self.as_ref(), &Value::Null).await } } @@ -319,7 +327,7 @@ where db: &mut Db, value: &T, ) -> Result>, Error> { - db.lock(self.as_ref().clone(), LockType::Write).await; + db.lock(self.as_ref().clone(), LockType::Write).await?; db.put(self.as_ref(), value).await } } @@ -490,7 +498,7 @@ where lock: bool, ) -> Result, Error> { if lock { - db.lock(self.as_ref().clone(), LockType::Exist).await; + db.lock(self.as_ref().clone(), LockType::Exist).await?; } let set = db.keys(self.as_ref(), None).await?; Ok(set @@ -499,7 +507,7 @@ where .collect::>()?) } pub async fn remove(&self, db: &mut Db, key: &T::Key) -> Result<(), Error> { - db.lock(self.as_ref().clone(), LockType::Write).await; + db.lock(self.as_ref().clone(), LockType::Write).await?; if db.exists(self.clone().idx(key).as_ref(), None).await? { db.apply( DiffPatch(Patch(vec![PatchOperation::Remove(RemoveOperation { diff --git a/patch-db/src/transaction.rs b/patch-db/src/transaction.rs index 24937b6..acc794e 100644 --- a/patch-db/src/transaction.rs +++ b/patch-db/src/transaction.rs @@ -10,7 +10,7 @@ use tokio::sync::broadcast::Receiver; use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use crate::handle::HandleId; -use crate::locker::{Guard, LockType, Locker}; +use crate::locker::{Guard, LockError, LockType, Locker}; use crate::patch::{DiffPatch, Revision}; use crate::store::Store; use crate::{DbHandle, Error, PatchDbHandle}; @@ -165,13 +165,13 @@ impl DbHandle for Transaction { self.updates.append(patch); Ok(None) } - async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) { - self.locks.push( + async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) -> Result<(), LockError> { + Ok(self.locks.push( self.parent .locker() .lock(self.id.clone(), ptr, lock_type) - .await, - ) + .await?, + )) } async fn get< T: for<'de> Deserialize<'de>,