From c2e50d0e88134f5871e7d64523a039563426122a Mon Sep 17 00:00:00 2001 From: Aiden McClelland Date: Tue, 9 Mar 2021 19:30:56 -0700 Subject: [PATCH] fix locking logic --- Cargo.toml | 6 + patch-db-derive-internals/Cargo.toml | 15 ++ patch-db-derive-internals/src/lib.rs | 42 +++++ patch-db-derive/Cargo.toml | 16 ++ patch-db-derive/src/lib.rs | 13 ++ patch-db/Cargo.toml | 2 +- patch-db/src/lib.rs | 255 ++++++++++++++++++++++----- 7 files changed, 305 insertions(+), 44 deletions(-) create mode 100644 Cargo.toml create mode 100644 patch-db-derive-internals/Cargo.toml create mode 100644 patch-db-derive-internals/src/lib.rs create mode 100644 patch-db-derive/Cargo.toml create mode 100644 patch-db-derive/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..ea93a0b --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,6 @@ +[workspace] +members = [ + "patch-db", + "patch-db-derive", + "patch-db-derive-internals", +] diff --git a/patch-db-derive-internals/Cargo.toml b/patch-db-derive-internals/Cargo.toml new file mode 100644 index 0000000..0c58efc --- /dev/null +++ b/patch-db-derive-internals/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "patch-db-derive-internals" +version = "0.1.0" +authors = ["Aiden McClelland "] +edition = "2018" +description = "internals for derive macros for defining typed patch dbs" +license = "MIT" +repository = "https://github.com/dr-bonez/patch-db" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +syn = { version = "1.0.5", features = ["full", "extra-traits"] } +quote = "1.0.1" +proc-macro2 = "1.0.1" diff --git a/patch-db-derive-internals/src/lib.rs b/patch-db-derive-internals/src/lib.rs new file mode 100644 index 0000000..638038f --- /dev/null +++ b/patch-db-derive-internals/src/lib.rs @@ -0,0 +1,42 @@ +use proc_macro2::TokenStream; +use quote::quote; + +pub fn build_model(input: &syn::DeriveInput) -> TokenStream { + match &input.data { + syn::Data::Struct(struct_ast) => build_model_struct(input, struct_ast), + syn::Data::Enum(enum_ast) => build_model_enum(enum_ast), + syn::Data::Union(_) => panic!("Unions are not supported"), + } +} + +fn build_model_struct(input: &syn::DeriveInput, ast: &syn::DataStruct) -> TokenStream { + let model_name = syn::Ident::new( + &format!("{}Model", input.ident), + proc_macro2::Span::call_site(), + ); + let base_name = &input.ident; + let model_vis = &input.vis; + quote! { + #model_vis struct #model_name { + data: Option>, + ptr: json_ptr::JsonPointer, + tx: Tx, + } + impl #model_name { + pub fn get(&mut self, lock: patch_db::LockType) -> Result<&#base_name, patch_db::Error> { + if let Some(data) = self.data.as_ref() { + match lock { + patch_db::LockType::None => Ok(data), + + } + } else { + self.tx.get(&self.ptr, lock) + } + } + } + } +} + +fn build_model_enum(ast: &syn::DataEnum) -> TokenStream { + todo!() +} diff --git a/patch-db-derive/Cargo.toml b/patch-db-derive/Cargo.toml new file mode 100644 index 0000000..b5577a6 --- /dev/null +++ b/patch-db-derive/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "patch-db-derive" +version = "0.1.0" +authors = ["Aiden McClelland "] +edition = "2018" +description = "derive macros for defining typed patch dbs" +license = "MIT" +repository = "https://github.com/dr-bonez/patch-db" + +[lib] +proc-macro = true +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +patch-db-derive-internals = { path = "../patch-db-derive-internals" } +syn = "1.0.62" diff --git a/patch-db-derive/src/lib.rs b/patch-db-derive/src/lib.rs new file mode 100644 index 0000000..be328bb --- /dev/null +++ b/patch-db-derive/src/lib.rs @@ -0,0 +1,13 @@ +extern crate proc_macro; + +use proc_macro::TokenStream; + +#[proc_macro_derive(Model, attributes(serde))] +pub fn model_derive(input: TokenStream) -> TokenStream { + // Construct a representation of Rust code as a syntax tree + // that we can manipulate + let ast = syn::parse(input).unwrap(); + + // Build the trait implementation + patch_db_derive_internals::build_model(&ast).into() +} diff --git a/patch-db/Cargo.toml b/patch-db/Cargo.toml index 4879aa4..36dda6b 100644 --- a/patch-db/Cargo.toml +++ b/patch-db/Cargo.toml @@ -12,7 +12,7 @@ fd-lock-rs = "0.1.3" futures = "0.3.8" json-patch = { path = "../../json-patch" } json-ptr = { path = "../../json-ptr" } -nix = "0.19.1" +nix = "0.20.0" qutex-2 = { path = "../../qutex" } serde = { version = "1.0.118", features = ["rc"] } serde_json = "1.0.61" diff --git a/patch-db/src/lib.rs b/patch-db/src/lib.rs index 9014a13..af70c22 100644 --- a/patch-db/src/lib.rs +++ b/patch-db/src/lib.rs @@ -16,7 +16,7 @@ use tokio::{ fs::File, sync::{ broadcast::{Receiver, Sender}, - RwLock, + Mutex, RwLock, }, }; @@ -228,9 +228,28 @@ pub trait Checkpoint { &'a self, ptr: &'a JsonPointer, ) -> BoxFuture>; - fn locks(&self) -> &[(JsonPointer, LockerGuard)]; - fn locker(&self) -> &Locker; + fn locker_and_locks(&mut self) -> (&Locker, Vec<&mut [(JsonPointer, LockerGuard)]>); fn apply(&mut self, patch: DiffPatch); + fn get< + 'a, + T: for<'de> Deserialize<'de> + 'a, + S: AsRef + Clone + Send + Sync + 'a, + V: SegList + Clone + Send + Sync + 'a, + >( + &'a mut self, + ptr: &'a JsonPointer, + lock: LockType, + ) -> BoxFuture<'a, Result>; + fn put< + 'a, + T: Serialize + Send + Sync + 'a, + S: AsRef + Send + Sync + 'a, + V: SegList + Send + Sync + 'a, + >( + &'a mut self, + ptr: &'a JsonPointer, + value: &'a T, + ) -> BoxFuture<'a, Result<(), Error>>; } pub struct Transaction { @@ -276,8 +295,18 @@ impl Transaction { ) { match lock { LockType::None => (), - LockType::Read => self.db.locker.add_read_lock(ptr, &mut self.locks).await, - LockType::Write => self.db.locker.add_write_lock(ptr, &mut self.locks).await, + LockType::Read => { + self.db + .locker + .add_read_lock(ptr, &mut self.locks, &mut []) + .await + } + LockType::Write => { + self.db + .locker + .add_write_lock(ptr, &mut self.locks, &mut []) + .await + } } } pub async fn get Deserialize<'de>, S: AsRef + Clone, V: SegList + Clone>( @@ -292,7 +321,7 @@ impl Transaction { } pub async fn put, V: SegList>( &mut self, - ptr: &JsonPointer, + ptr: &JsonPointer, value: &T, ) -> Result<(), Error> { let old = Transaction::get_value(self, ptr).await?; @@ -314,15 +343,36 @@ impl<'a> Checkpoint for &'a mut Transaction { ) -> BoxFuture<'b, Result> { Transaction::get_value(self, ptr).boxed() } - fn locks(&self) -> &[(JsonPointer, LockerGuard)] { - &self.locks - } - fn locker(&self) -> &Locker { - &self.db.locker + fn locker_and_locks(&mut self) -> (&Locker, Vec<&mut [(JsonPointer, LockerGuard)]>) { + (&self.db.locker, vec![&mut self.locks]) } fn apply(&mut self, patch: DiffPatch) { (self.updates.0).0.extend((patch.0).0) } + fn get< + 'b, + T: for<'de> Deserialize<'de> + 'b, + S: AsRef + Clone + Send + Sync + 'b, + V: SegList + Clone + Send + Sync + 'b, + >( + &'b mut self, + ptr: &'b JsonPointer, + lock: LockType, + ) -> BoxFuture<'b, Result> { + Transaction::get(self, ptr, lock).boxed() + } + fn put< + 'b, + T: Serialize + Send + Sync + 'b, + S: AsRef + Send + Sync + 'b, + V: SegList + Send + Sync + 'b, + >( + &'b mut self, + ptr: &'b JsonPointer, + value: &'b T, + ) -> BoxFuture<'b, Result<(), Error>> { + Transaction::put(self, ptr, value).boxed() + } } pub struct SubTransaction { @@ -363,23 +413,16 @@ impl SubTransaction { ptr: &JsonPointer, lock: LockType, ) { - for lock in self.locks.iter() { - if ptr.starts_with(&lock.0) { - return; - } - } match lock { LockType::None => (), LockType::Read => { - self.parent - .locker() - .add_read_lock(ptr, &mut self.locks) - .await + let (locker, mut locks) = self.parent.locker_and_locks(); + locker.add_read_lock(ptr, &mut self.locks, &mut locks).await } LockType::Write => { - self.parent - .locker() - .add_write_lock(ptr, &mut self.locks) + let (locker, mut locks) = self.parent.locker_and_locks(); + locker + .add_write_lock(ptr, &mut self.locks, &mut locks) .await } } @@ -400,7 +443,7 @@ impl SubTransaction { } pub async fn put + Send + Sync, V: SegList + Send + Sync>( &mut self, - ptr: &JsonPointer, + ptr: &JsonPointer, value: &T, ) -> Result<(), Error> { let old = SubTransaction::get_value(self, ptr).await?; @@ -422,15 +465,38 @@ impl<'a, Tx: Checkpoint + Send + Sync> Checkpoint for &'a mut SubTransaction ) -> BoxFuture<'b, Result> { SubTransaction::get_value(self, ptr).boxed() } - fn locks(&self) -> &[(JsonPointer, LockerGuard)] { - &self.locks - } - fn locker(&self) -> &Locker { - &self.parent.locker() + fn locker_and_locks(&mut self) -> (&Locker, Vec<&mut [(JsonPointer, LockerGuard)]>) { + let (locker, mut locks) = self.parent.locker_and_locks(); + locks.push(&mut self.locks); + (locker, locks) } fn apply(&mut self, patch: DiffPatch) { (self.updates.0).0.extend((patch.0).0) } + fn get< + 'b, + T: for<'de> Deserialize<'de> + 'b, + S: AsRef + Clone + Send + Sync + 'b, + V: SegList + Clone + Send + Sync + 'b, + >( + &'b mut self, + ptr: &'b JsonPointer, + lock: LockType, + ) -> BoxFuture<'b, Result> { + SubTransaction::get(self, ptr, lock).boxed() + } + fn put< + 'b, + T: Serialize + Send + Sync + 'b, + S: AsRef + Send + Sync + 'b, + V: SegList + Send + Sync + 'b, + >( + &'b mut self, + ptr: &'b JsonPointer, + value: &'b T, + ) -> BoxFuture<'b, Result<(), Error>> { + SubTransaction::put(self, ptr, value).boxed() + } } #[derive(Debug)] @@ -442,8 +508,8 @@ pub enum LockType { pub enum LockerGuard { Empty, - Read(ReadGuard>), - Write(WriteGuard>), + Read(LockerReadGuard), + Write(LockerWriteGuard), } impl LockerGuard { pub fn take(&mut self) -> Self { @@ -451,6 +517,44 @@ impl LockerGuard { } } +#[derive(Debug, Clone)] +pub struct LockerReadGuard(Arc>>>>); +impl LockerReadGuard { + async fn upgrade(&self) -> Option { + let guard = self.0.try_lock().unwrap().take(); + if let Some(g) = guard { + Some(LockerWriteGuard( + Some(ReadGuard::upgrade(g).await.unwrap()), + Some(self.clone()), + )) + } else { + None + } + } +} +impl From>> for LockerReadGuard { + fn from(guard: ReadGuard>) -> Self { + LockerReadGuard(Arc::new(Mutex::new(Some(guard)))) + } +} + +pub struct LockerWriteGuard( + Option>>, + Option, +); +impl From>> for LockerWriteGuard { + fn from(guard: WriteGuard>) -> Self { + LockerWriteGuard(Some(guard), None) + } +} +impl Drop for LockerWriteGuard { + fn drop(&mut self) { + if let (Some(write), Some(read)) = (self.0.take(), self.1.take()) { + *read.0.try_lock().unwrap() = Some(WriteGuard::downgrade(write)); + } + } +} + #[derive(Clone, Debug)] pub struct Locker(QrwLock>); impl Locker { @@ -475,24 +579,29 @@ impl Locker { } lock.unwrap() } - pub async fn add_read_lock + Clone, V: SegList + Clone>( + async fn add_read_lock + Clone, V: SegList + Clone>( &self, ptr: &JsonPointer, locks: &mut Vec<(JsonPointer, LockerGuard)>, + extra_locks: &mut [&mut [(JsonPointer, LockerGuard)]], ) { - for lock in locks.iter() { + for lock in extra_locks + .iter() + .flat_map(|a| a.iter()) + .chain(locks.iter()) + { if ptr.starts_with(&lock.0) { return; } } locks.push(( JsonPointer::to_owned(ptr.clone()), - LockerGuard::Read(self.lock_read(ptr).await), + LockerGuard::Read(self.lock_read(ptr).await.into()), )); } pub async fn lock_write, V: SegList>( &self, - ptr: &JsonPointer, + ptr: &JsonPointer, ) -> WriteGuard> { let mut lock = self.0.clone().write().await.unwrap(); for seg in ptr.iter() { @@ -506,26 +615,86 @@ impl Locker { } lock } - pub async fn add_write_lock + Clone, V: SegList + Clone>( + async fn add_write_lock + Clone, V: SegList + Clone>( &self, ptr: &JsonPointer, locks: &mut Vec<(JsonPointer, LockerGuard)>, + extra_locks: &mut [&mut [(JsonPointer, LockerGuard)]], ) { - for lock in locks.iter_mut() { - if ptr.starts_with(&lock.0) { + let mut final_lock = None; + for lock in extra_locks + .iter_mut() + .flat_map(|a| a.iter_mut()) + .chain(locks.iter_mut()) + { + enum Choice { + Return, + Continue, + Break, + } + let choice: Choice; + if let Some(remainder) = ptr.strip_prefix(&lock.0) { let guard = lock.1.take(); lock.1 = match guard { - LockerGuard::Read(l) => { - LockerGuard::Write(ReadGuard::upgrade(l).await.unwrap()) + LockerGuard::Read(LockerReadGuard(guard)) if !remainder.is_empty() => { + // read guard already exists at higher level + let mut lock = guard.lock().await; + if let Some(l) = lock.take() { + let mut orig_lock = None; + let mut lock = ReadGuard::upgrade(l).await.unwrap(); + for seg in remainder.iter() { + let new_lock = if let Some(locker) = lock.get(seg) { + locker.0.clone().write().await.unwrap() + } else { + lock.insert(seg.to_owned(), Locker::new()); + lock.get(seg).unwrap().0.clone().write().await.unwrap() + }; + if orig_lock.is_none() { + orig_lock = Some(lock); + } + lock = new_lock; + } + final_lock = Some(LockerGuard::Write(lock.into())); + choice = Choice::Break; + LockerGuard::Read(WriteGuard::downgrade(orig_lock.unwrap()).into()) + } else { + drop(lock); + choice = Choice::Return; + LockerGuard::Read(LockerReadGuard(guard)) + } + } + LockerGuard::Read(l) => { + // read exists, convert to write + if let Some(upgraded) = l.upgrade().await { + final_lock = Some(LockerGuard::Write(upgraded)); + choice = Choice::Break; + } else { + choice = Choice::Continue; + } + LockerGuard::Read(l) + } + LockerGuard::Write(l) => { + choice = Choice::Return; + LockerGuard::Write(l) + } // leave it alone, already sufficiently locked + LockerGuard::Empty => { + unreachable!("LockerGuard found empty"); } - a => a, }; - return; + match choice { + Choice::Return => return, + Choice::Break => break, + Choice::Continue => continue, + } } } locks.push(( JsonPointer::to_owned(ptr.clone()), - LockerGuard::Read(self.lock_read(ptr).await), + if let Some(lock) = final_lock { + lock + } else { + LockerGuard::Write(self.lock_write(ptr).await.into()) + }, )); } }