diff --git a/patch-db/src/handle.rs b/patch-db/src/handle.rs index cf03670..a234f85 100644 --- a/patch-db/src/handle.rs +++ b/patch-db/src/handle.rs @@ -14,7 +14,7 @@ use crate::{ use crate::{patch::DiffPatch, Error}; #[async_trait] -pub trait DbHandle: Sized + Send + Sync { +pub trait DbHandle: Send + Sync { async fn begin<'a>(&'a mut self) -> Result, Error>; fn rebase(&mut self) -> Result<(), Error>; fn store(&self) -> Arc>; @@ -45,6 +45,7 @@ pub trait DbHandle: Sized + Send + Sync { &mut self, ptr: &JsonPointer, lock: LockType, + deep: bool, ) -> (); async fn get< T: for<'de> Deserialize<'de>, @@ -65,7 +66,7 @@ pub trait DbHandle: Sized + Send + Sync { ) -> Result<(), Error>; } #[async_trait] -impl DbHandle for &mut Handle { +impl DbHandle for &mut Handle { async fn begin<'a>(&'a mut self) -> Result, Error> { let Transaction { locks, @@ -127,8 +128,9 @@ impl DbHandle for &mut Handle { &mut self, ptr: &JsonPointer, lock: LockType, + deep: bool, ) { - (*self).lock(ptr, lock).await + (*self).lock(ptr, lock, deep).await } async fn get< T: for<'de> Deserialize<'de>, @@ -229,18 +231,19 @@ impl DbHandle for PatchDbHandle { &mut self, ptr: &JsonPointer, lock: LockType, + deep: bool, ) { match lock { LockType::Read => { self.db .locker - .add_read_lock(ptr, &mut self.locks, &mut []) + .add_read_lock(ptr, &mut self.locks, &mut [], deep) .await; } LockType::Write => { self.db .locker - .add_write_lock(ptr, &mut self.locks, &mut []) + .add_write_lock(ptr, &mut self.locks, &mut [], deep) .await; } LockType::None => (), diff --git a/patch-db/src/locker.rs b/patch-db/src/locker.rs index 17de1a9..e59079a 100644 --- a/patch-db/src/locker.rs +++ b/patch-db/src/locker.rs @@ -80,6 +80,7 @@ impl Locker { pub async fn lock_read, V: SegList>( &self, ptr: &JsonPointer, + deep: bool, ) -> ReadGuard> { let mut lock = Some(self.0.clone().read().await.unwrap()); for seg in ptr.iter() { @@ -94,7 +95,9 @@ impl Locker { lock = Some(new_lock); } let res = lock.unwrap(); - Self::lock_root_read(&res); + if deep { + Self::lock_root_read(&res); + } res } pub(crate) async fn add_read_lock + Clone, V: SegList + Clone>( @@ -102,6 +105,7 @@ impl Locker { ptr: &JsonPointer, locks: &mut Vec<(JsonPointer, LockerGuard)>, extra_locks: &mut [&mut [(JsonPointer, LockerGuard)]], + deep: bool, ) { for lock in extra_locks .iter() @@ -114,7 +118,7 @@ impl Locker { } locks.push(( JsonPointer::to_owned(ptr.clone()), - LockerGuard::Read(self.lock_read(ptr).await.into()), + LockerGuard::Read(self.lock_read(ptr, deep).await.into()), )); } fn lock_root_write<'a>(guard: &'a WriteGuard>) -> BoxFuture<'a, ()> { @@ -129,6 +133,7 @@ impl Locker { pub async fn lock_write, V: SegList>( &self, ptr: &JsonPointer, + deep: bool, ) -> WriteGuard> { let mut lock = self.0.clone().write().await.unwrap(); for seg in ptr.iter() { @@ -141,7 +146,9 @@ impl Locker { lock = new_lock; } let res = lock; - Self::lock_root_write(&res); + if deep { + Self::lock_root_write(&res); + } res } pub(crate) async fn add_write_lock + Clone, V: SegList + Clone>( @@ -149,6 +156,7 @@ impl Locker { ptr: &JsonPointer, locks: &mut Vec<(JsonPointer, LockerGuard)>, // tx locks extra_locks: &mut [&mut [(JsonPointer, LockerGuard)]], // tx parent locks + deep: bool, ) { let mut final_lock = None; for lock in extra_locks @@ -222,7 +230,7 @@ impl Locker { if let Some(lock) = final_lock { lock } else { - LockerGuard::Write(self.lock_write(ptr).await.into()) + LockerGuard::Write(self.lock_write(ptr, deep).await.into()) }, )); } diff --git a/patch-db/src/model.rs b/patch-db/src/model.rs index fdc1137..3625571 100644 --- a/patch-db/src/model.rs +++ b/patch-db/src/model.rs @@ -64,11 +64,13 @@ where T: Serialize + for<'de> Deserialize<'de>, { pub async fn lock(&self, db: &mut Db, lock: LockType) { - db.lock(&self.ptr, lock).await + db.lock(&self.ptr, lock, true).await } - pub async fn get(&self, db: &mut Db) -> Result, Error> { - self.lock(db, LockType::Read).await; + pub async fn get(&self, db: &mut Db, lock: bool) -> Result, Error> { + if lock { + self.lock(db, LockType::Read).await; + } Ok(ModelData(db.get(&self.ptr).await?)) } @@ -202,11 +204,17 @@ impl Deserialize<'de>> HasModel for Box { pub struct OptionModel Deserialize<'de>>(T::Model); impl Deserialize<'de>> OptionModel { pub async fn lock(&self, db: &mut Db, lock: LockType) { - db.lock(self.0.as_ref(), lock).await + db.lock(self.0.as_ref(), lock, true).await } - pub async fn get(&self, db: &mut Db) -> Result>, Error> { - self.lock(db, LockType::Read).await; + pub async fn get( + &self, + db: &mut Db, + lock: bool, + ) -> Result>, Error> { + if lock { + self.lock(db, LockType::Read).await; + } Ok(ModelData(db.get(self.0.as_ref()).await?)) } @@ -224,8 +232,10 @@ impl Deserialize<'de>> OptionModel { }) } - pub async fn exists(&self, db: &mut Db) -> Result { - self.lock(db, LockType::Read).await; + pub async fn exists(&self, db: &mut Db, lock: bool) -> Result { + if lock { + db.lock(self.0.as_ref(), LockType::Read, false).await; + } Ok(db.exists(&self.as_ref(), None).await?) } @@ -252,7 +262,7 @@ impl Deserialize<'de>> OptionModel { } pub async fn check(self, db: &mut Db) -> Result, Error> { - Ok(if self.exists(db).await? { + Ok(if self.exists(db, true).await? { Some(self.0) } else { None @@ -260,7 +270,7 @@ impl Deserialize<'de>> OptionModel { } pub async fn expect(self, db: &mut Db) -> Result { - if self.exists(db).await? { + if self.exists(db, true).await? { Ok(self.0) } else { Err(Error::NodeDoesNotExist(self.0.into())) @@ -268,7 +278,7 @@ impl Deserialize<'de>> OptionModel { } pub async fn delete(&self, db: &mut Db) -> Result<(), Error> { - db.lock(self.as_ref(), LockType::Write).await; + db.lock(self.as_ref(), LockType::Write, true).await; db.put(self.as_ref(), &Value::Null).await } } @@ -277,7 +287,7 @@ where T: Serialize + for<'de> Deserialize<'de> + Send + Sync + HasModel, { pub async fn put(&self, db: &mut Db, value: &T) -> Result<(), Error> { - db.lock(self.as_ref(), LockType::Write).await; + db.lock(self.as_ref(), LockType::Write, true).await; db.put(self.as_ref(), value).await } } @@ -435,8 +445,14 @@ where T::Key: Hash + Eq + for<'de> Deserialize<'de>, T::Value: Serialize + for<'de> Deserialize<'de>, { - pub async fn keys(&self, db: &mut Db) -> Result, Error> { - db.lock(self.as_ref(), LockType::Read).await; + pub async fn keys( + &self, + db: &mut Db, + lock: bool, + ) -> Result, Error> { + if lock { + db.lock(self.as_ref(), LockType::Read, false).await; + } let set = db.keys(self.as_ref(), None).await?; Ok(set .into_iter() diff --git a/patch-db/src/transaction.rs b/patch-db/src/transaction.rs index 95a72cd..f8679d8 100644 --- a/patch-db/src/transaction.rs +++ b/patch-db/src/transaction.rs @@ -159,17 +159,20 @@ impl DbHandle for Transaction { &mut self, ptr: &JsonPointer, lock: LockType, + deep: bool, ) { match lock { LockType::None => (), LockType::Read => { let (locker, mut locks) = self.parent.locker_and_locks(); - locker.add_read_lock(ptr, &mut self.locks, &mut locks).await + locker + .add_read_lock(ptr, &mut self.locks, &mut locks, deep) + .await } LockType::Write => { let (locker, mut locks) = self.parent.locker_and_locks(); locker - .add_write_lock(ptr, &mut self.locks, &mut locks) + .add_write_lock(ptr, &mut self.locks, &mut locks, deep) .await } }