diff --git a/patch-db/src/handle.rs b/patch-db/src/handle.rs index ebf982e..5e148db 100644 --- a/patch-db/src/handle.rs +++ b/patch-db/src/handle.rs @@ -7,6 +7,7 @@ use serde_json::Value; use std::collections::BTreeSet; use tokio::sync::{broadcast::Receiver, RwLock, RwLockReadGuard}; +use crate::locker::LockType; use crate::{locker::Guard, Locker, PatchDb, Revision, Store, Transaction}; use crate::{patch::DiffPatch, Error}; @@ -42,7 +43,7 @@ pub trait DbHandle: Send + Sync { value: &Value, ) -> Result>, Error>; async fn apply(&mut self, patch: DiffPatch) -> Result>, Error>; - async fn lock(&mut self, ptr: JsonPointer, write: bool) -> (); + async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) -> (); async fn get< T: for<'de> Deserialize<'de>, S: AsRef + Send + Sync, @@ -124,8 +125,8 @@ impl DbHandle for &mut Handle { async fn apply(&mut self, patch: DiffPatch) -> Result>, Error> { (*self).apply(patch).await } - async fn lock(&mut self, ptr: JsonPointer, write: bool) { - (*self).lock(ptr, write).await + async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) { + (*self).lock(ptr, lock_type).await } async fn get< T: for<'de> Deserialize<'de>, @@ -232,9 +233,9 @@ impl DbHandle for PatchDbHandle { async fn apply(&mut self, patch: DiffPatch) -> Result>, Error> { self.db.apply(patch, None, None).await } - async fn lock(&mut self, ptr: JsonPointer, write: bool) { + async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) { self.locks - .push(self.db.locker.lock(self.id.clone(), ptr, write).await); + .push(self.db.locker.lock(self.id.clone(), ptr, lock_type).await); } async fn get< T: for<'de> Deserialize<'de>, diff --git a/patch-db/src/locker.rs b/patch-db/src/locker.rs index b3be181..32e32d4 100644 --- a/patch-db/src/locker.rs +++ b/patch-db/src/locker.rs @@ -36,14 +36,14 @@ impl Locker { }); Locker { sender } } - pub async fn lock(&self, handle_id: HandleId, ptr: JsonPointer, write: bool) -> Guard { + pub async fn lock(&self, handle_id: HandleId, ptr: JsonPointer, lock_type: LockType) -> Guard { let (send, recv) = oneshot::channel(); self.sender .send(Request { lock_info: LockInfo { handle_id, ptr, - write, + ty: lock_type, segments_handled: 0, }, completion: send, @@ -127,68 +127,141 @@ impl Trie { #[derive(Debug, Default)] struct Node { + reader_parents: Vec, readers: Vec, + writer_parents: Vec, writers: Vec, reqs: Vec, } impl Node { - // true: If there are any writers, it is `id`. - fn write_free(&self, id: &HandleId) -> bool { - self.writers.is_empty() || (self.writers.iter().filter(|a| a != &id).count() == 0) + // true: If there are any writer_parents, they are `id`. + fn write_parent_free(&self, id: &HandleId) -> bool { + self.writer_parents.is_empty() || (self.writer_parents.iter().find(|a| a != &id).is_none()) } - // true: If there are any readers, it is `id`. + // true: If there are any writers, they are `id`. + fn write_free(&self, id: &HandleId) -> bool { + self.writers.is_empty() || (self.writers.iter().find(|a| a != &id).is_none()) + } + // true: If there are any reader_parents, they are `id`. + fn read_parent_free(&self, id: &HandleId) -> bool { + self.reader_parents.is_empty() || (self.reader_parents.iter().find(|a| a != &id).is_none()) + } + // true: If there are any readers, they are `id`. fn read_free(&self, id: &HandleId) -> bool { - self.readers.is_empty() || (self.readers.iter().filter(|a| a != &id).count() == 0) + self.readers.is_empty() || (self.readers.iter().find(|a| a != &id).is_none()) } // allow a lock to skip the queue if a lock is already held by the same handle fn can_jump_queue(&self, id: &HandleId) -> bool { - self.writers.contains(&id) || self.readers.contains(&id) + self.writers.contains(&id) + || self.writer_parents.contains(&id) + || self.readers.contains(&id) + || self.reader_parents.contains(&id) } - // `id` is capable of acquiring this node for writing - fn write_available(&self, id: &HandleId) -> bool { + // `id` is capable of acquiring this node for the purpose of writing to a child + fn write_parent_available(&self, id: &HandleId) -> bool { self.write_free(id) && self.read_free(id) && (self.reqs.is_empty() || self.can_jump_queue(id)) } + // `id` is capable of acquiring this node for writing + fn write_available(&self, id: &HandleId) -> bool { + self.write_free(id) + && self.write_parent_free(id) + && self.read_free(id) + && self.read_parent_free(id) + && (self.reqs.is_empty() || self.can_jump_queue(id)) + } + fn read_parent_available(&self, id: &HandleId) -> bool { + self.write_free(id) && (self.reqs.is_empty() || self.can_jump_queue(id)) + } // `id` is capable of acquiring this node for reading fn read_available(&self, id: &HandleId) -> bool { - self.write_free(id) && (self.reqs.is_empty() || self.can_jump_queue(id)) + self.write_free(id) + && self.write_parent_free(id) + && (self.reqs.is_empty() || self.can_jump_queue(id)) } fn handle_request( &mut self, req: Request, locks_on_lease: &mut Vec>, ) -> Option { - if req.lock_info.write() && self.write_available(&req.lock_info.handle_id) { - self.writers.push(req.lock_info.handle_id.clone()); - req.process(locks_on_lease) - } else if !req.lock_info.write() && self.read_available(&req.lock_info.handle_id) { - self.readers.push(req.lock_info.handle_id.clone()); - req.process(locks_on_lease) - } else { - self.reqs.push(req); - None + match ( + req.lock_info.ty, + req.lock_info.segments_handled == req.lock_info.ptr.len(), + ) { + (LockType::Write, true) if self.write_available(&req.lock_info.handle_id) => { + self.writers.push(req.lock_info.handle_id.clone()); + req.process(locks_on_lease) + } + (LockType::DeepRead, true) if self.read_available(&req.lock_info.handle_id) => { + self.readers.push(req.lock_info.handle_id.clone()); + req.process(locks_on_lease) + } + (LockType::Write, false) if self.write_parent_available(&req.lock_info.handle_id) => { + self.writer_parents.push(req.lock_info.handle_id.clone()); + req.process(locks_on_lease) + } + (LockType::DeepRead, false) | (LockType::ShallowRead, _) + if self.read_parent_available(&req.lock_info.handle_id) => + { + self.reader_parents.push(req.lock_info.handle_id.clone()); + req.process(locks_on_lease) + } + _ => { + self.reqs.push(req); + None + } } } fn release(&mut self, mut lock_info: LockInfo) -> Option { - if lock_info.write() { - if let Some(idx) = self - .writers - .iter() - .enumerate() - .find(|(_, id)| id == &&lock_info.handle_id) - .map(|(idx, _)| idx) - { - self.writers.swap_remove(idx); + match ( + lock_info.ty, + lock_info.segments_handled == lock_info.ptr.len(), + ) { + (LockType::Write, true) => { + if let Some(idx) = self + .writers + .iter() + .enumerate() + .find(|(_, id)| id == &&lock_info.handle_id) + .map(|(idx, _)| idx) + { + self.writers.swap_remove(idx); + } + } + (LockType::DeepRead, true) => { + if let Some(idx) = self + .writers + .iter() + .enumerate() + .find(|(_, id)| id == &&lock_info.handle_id) + .map(|(idx, _)| idx) + { + self.readers.swap_remove(idx); + } + } + (LockType::Write, false) => { + if let Some(idx) = self + .writer_parents + .iter() + .enumerate() + .find(|(_, id)| id == &&lock_info.handle_id) + .map(|(idx, _)| idx) + { + self.writer_parents.swap_remove(idx); + } + } + (LockType::DeepRead, false) | (LockType::ShallowRead, _) => { + if let Some(idx) = self + .reader_parents + .iter() + .enumerate() + .find(|(_, id)| id == &&lock_info.handle_id) + .map(|(idx, _)| idx) + { + self.reader_parents.swap_remove(idx); + } } - } else if let Some(idx) = self - .readers - .iter() - .enumerate() - .find(|(_, id)| id == &&lock_info.handle_id) - .map(|(idx, _)| idx) - { - assert!(lock_info.handle_id == self.readers.swap_remove(idx)); } if lock_info.ptr.len() == lock_info.segments_handled { None @@ -203,13 +276,10 @@ impl Node { struct LockInfo { ptr: JsonPointer, segments_handled: usize, - write: bool, + ty: LockType, handle_id: HandleId, } impl LockInfo { - fn write(&self) -> bool { - self.write && self.segments_handled == self.ptr.len() - } fn current_seg(&self) -> &str { if self.segments_handled == 0 { "" // root @@ -225,6 +295,18 @@ impl LockInfo { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum LockType { + ShallowRead, + DeepRead, + Write, +} +impl Default for LockType { + fn default() -> Self { + LockType::ShallowRead + } +} + #[derive(Debug)] struct Request { lock_info: LockInfo, @@ -233,18 +315,21 @@ struct Request { impl Request { fn process(mut self, locks_on_lease: &mut Vec>) -> Option { if self.lock_info.ptr.len() == self.lock_info.segments_handled { - let (sender, receiver) = oneshot::channel(); - locks_on_lease.push(receiver); - let _ = self.completion.send(Guard { - lock_info: self.lock_info.reset(), - sender: Some(sender), - }); + self.complete(locks_on_lease); None } else { self.lock_info.segments_handled += 1; Some(self) } } + fn complete(self, locks_on_lease: &mut Vec>) { + let (sender, receiver) = oneshot::channel(); + locks_on_lease.push(receiver); + let _ = self.completion.send(Guard { + lock_info: self.lock_info.reset(), + sender: Some(sender), + }); + } } #[derive(Debug)] diff --git a/patch-db/src/model.rs b/patch-db/src/model.rs index 2e36787..d9eb1dc 100644 --- a/patch-db/src/model.rs +++ b/patch-db/src/model.rs @@ -9,6 +9,7 @@ use json_ptr::JsonPointer; use serde::{Deserialize, Serialize}; use serde_json::Value; +use crate::locker::LockType; use crate::{DbHandle, DiffPatch, Error, Revision}; #[derive(Debug)] @@ -63,19 +64,19 @@ impl Model where T: Serialize + for<'de> Deserialize<'de>, { - pub async fn lock(&self, db: &mut Db, write: bool) { - db.lock(self.ptr.clone(), write).await + pub async fn lock(&self, db: &mut Db, lock_type: LockType) { + db.lock(self.ptr.clone(), lock_type).await } pub async fn get(&self, db: &mut Db, lock: bool) -> Result, Error> { if lock { - self.lock(db, false).await; + self.lock(db, LockType::DeepRead).await; } Ok(ModelData(db.get(&self.ptr).await?)) } pub async fn get_mut(&self, db: &mut Db) -> Result, Error> { - self.lock(db, true).await; + self.lock(db, LockType::Write).await; let original = db.get_value(&self.ptr, None).await?; let current = serde_json::from_value(original.clone())?; Ok(ModelDataMut { @@ -103,7 +104,7 @@ where db: &mut Db, value: &T, ) -> Result>, Error> { - self.lock(db, true).await; + self.lock(db, LockType::Write).await; db.put(&self.ptr, value).await } } @@ -226,8 +227,8 @@ impl Deserialize<'de>> HasModel for Box { #[derive(Debug)] pub struct OptionModel Deserialize<'de>>(T::Model); impl Deserialize<'de>> OptionModel { - pub async fn lock(&self, db: &mut Db, write: bool) { - db.lock(self.0.as_ref().clone(), write).await + pub async fn lock(&self, db: &mut Db, lock_type: LockType) { + db.lock(self.0.as_ref().clone(), lock_type).await } pub async fn get( @@ -236,7 +237,7 @@ impl Deserialize<'de>> OptionModel { lock: bool, ) -> Result>, Error> { if lock { - self.lock(db, false).await; + self.lock(db, LockType::DeepRead).await; } Ok(ModelData(db.get(self.0.as_ref()).await?)) } @@ -245,7 +246,7 @@ impl Deserialize<'de>> OptionModel { &self, db: &mut Db, ) -> Result>, Error> { - self.lock(db, true).await; + self.lock(db, LockType::Write).await; let original = db.get_value(self.0.as_ref(), None).await?; let current = serde_json::from_value(original.clone())?; Ok(ModelDataMut { @@ -257,7 +258,8 @@ impl Deserialize<'de>> OptionModel { pub async fn exists(&self, db: &mut Db, lock: bool) -> Result { if lock { - db.lock(self.0.as_ref().clone(), false).await; + db.lock(self.0.as_ref().clone(), LockType::ShallowRead) + .await; } Ok(db.exists(&self.as_ref(), None).await?) } @@ -283,7 +285,7 @@ impl Deserialize<'de>> OptionModel { } pub async fn delete(&self, db: &mut Db) -> Result>, Error> { - db.lock(self.as_ref().clone(), true).await; + db.lock(self.as_ref().clone(), LockType::Write).await; db.put(self.as_ref(), &Value::Null).await } } @@ -317,7 +319,7 @@ where db: &mut Db, value: &T, ) -> Result>, Error> { - db.lock(self.as_ref().clone(), true).await; + db.lock(self.as_ref().clone(), LockType::Write).await; db.put(self.as_ref(), value).await } } @@ -488,7 +490,7 @@ where lock: bool, ) -> Result, Error> { if lock { - db.lock(self.as_ref().clone(), false).await; + db.lock(self.as_ref().clone(), LockType::ShallowRead).await; } let set = db.keys(self.as_ref(), None).await?; Ok(set @@ -497,13 +499,15 @@ where .collect::>()?) } pub async fn remove(&self, db: &mut Db, key: &T::Key) -> Result<(), Error> { - db.lock(self.as_ref().clone(), true).await; - db.apply(DiffPatch(Patch(vec![PatchOperation::Remove( - RemoveOperation { - path: self.as_ref().clone().join_end(key.as_ref()), - }, - )]))) - .await?; + db.lock(self.as_ref().clone(), LockType::Write).await; + if db.exists(self.clone().idx(key).as_ref(), None).await? { + db.apply(DiffPatch(Patch(vec![PatchOperation::Remove( + RemoveOperation { + path: self.as_ref().clone().join_end(key.as_ref()), + }, + )]))) + .await?; + } Ok(()) } } diff --git a/patch-db/src/transaction.rs b/patch-db/src/transaction.rs index c8dbdc2..44045d0 100644 --- a/patch-db/src/transaction.rs +++ b/patch-db/src/transaction.rs @@ -10,7 +10,7 @@ use tokio::sync::broadcast::Receiver; use tokio::sync::{RwLock, RwLockReadGuard}; use crate::handle::HandleId; -use crate::locker::{Guard, Locker}; +use crate::locker::{Guard, LockType, Locker}; use crate::patch::{DiffPatch, Revision}; use crate::store::Store; use crate::{DbHandle, Error, PatchDbHandle}; @@ -166,9 +166,13 @@ impl DbHandle for Transaction { self.updates.append(patch); Ok(None) } - async fn lock(&mut self, ptr: JsonPointer, write: bool) { - self.locks - .push(self.parent.locker().lock(self.id.clone(), ptr, write).await) + async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) { + self.locks.push( + self.parent + .locker() + .lock(self.id.clone(), ptr, lock_type) + .await, + ) } async fn get< T: for<'de> Deserialize<'de>,