From 35973d7aef054842faa13c82f357252563108949 Mon Sep 17 00:00:00 2001 From: J M <2364004+Blu-J@users.noreply.github.com> Date: Mon, 9 May 2022 14:19:08 -0600 Subject: [PATCH] Feat new locking (#31) This converts the old singular locks into plural locks. --- patch-db-macro-internals/src/lib.rs | 311 +++++++------ .../proptest-regressions/locker/proptest.txt | 7 + patch-db/proptest-regressions/model_paths.txt | 7 + patch-db/src/bulk_locks.rs | 167 +++++++ patch-db/src/bulk_locks/unsaturated_args.rs | 70 +++ patch-db/src/handle.rs | 147 +++++- patch-db/src/lib.rs | 12 + patch-db/src/locker/action_mux.rs | 16 +- patch-db/src/locker/bookkeeper.rs | 123 +++-- patch-db/src/locker/mod.rs | 209 ++++++++- patch-db/src/locker/order_enforcer.rs | 58 ++- patch-db/src/locker/proptest.rs | 113 +++-- patch-db/src/locker/trie.rs | 261 +++++++---- patch-db/src/model.rs | 225 ++++++++-- patch-db/src/model_paths.rs | 419 ++++++++++++++++++ patch-db/src/test.rs | 4 +- patch-db/src/transaction.rs | 26 +- 17 files changed, 1795 insertions(+), 380 deletions(-) create mode 100644 patch-db/proptest-regressions/locker/proptest.txt create mode 100644 patch-db/proptest-regressions/model_paths.txt create mode 100644 patch-db/src/bulk_locks.rs create mode 100644 patch-db/src/bulk_locks/unsaturated_args.rs create mode 100644 patch-db/src/model_paths.rs diff --git a/patch-db-macro-internals/src/lib.rs b/patch-db-macro-internals/src/lib.rs index 391f2a5..58d6f77 100644 --- a/patch-db-macro-internals/src/lib.rs +++ b/patch-db-macro-internals/src/lib.rs @@ -101,159 +101,176 @@ fn build_model_struct( _ => None, }); - child_path.push(Some( - match (serde_rename, serde_rename_all.as_ref().map(|s| s.as_str())) { - (Some(a), _) => a, - (None, Some("lowercase")) => LitStr::new( - &heck::CamelCase::to_camel_case(ident.to_string().as_str()) - .to_lowercase(), - ident.span(), + child_path.push(Some(match (serde_rename, serde_rename_all.as_deref()) { + (Some(a), _) => a, + (None, Some("lowercase")) => LitStr::new( + &heck::CamelCase::to_camel_case(ident.to_string().as_str()) + .to_lowercase(), + ident.span(), + ), + (None, Some("UPPERCASE")) => LitStr::new( + &heck::CamelCase::to_camel_case(ident.to_string().as_str()) + .to_uppercase(), + ident.span(), + ), + (None, Some("PascalCase")) => LitStr::new( + &heck::CamelCase::to_camel_case(ident.to_string().as_str()), + ident.span(), + ), + (None, Some("camelCase")) => LitStr::new( + &heck::MixedCase::to_mixed_case(ident.to_string().as_str()), + ident.span(), + ), + (None, Some("SCREAMING_SNAKE_CASE")) => LitStr::new( + &heck::ShoutySnakeCase::to_shouty_snake_case( + ident.to_string().as_str(), ), - (None, Some("UPPERCASE")) => LitStr::new( - &heck::CamelCase::to_camel_case(ident.to_string().as_str()) - .to_uppercase(), - ident.span(), + ident.span(), + ), + (None, Some("kebab-case")) => LitStr::new( + &heck::KebabCase::to_kebab_case(ident.to_string().as_str()), + ident.span(), + ), + (None, Some("SCREAMING-KEBAB-CASE")) => LitStr::new( + &heck::ShoutyKebabCase::to_shouty_kebab_case( + ident.to_string().as_str(), ), - (None, Some("PascalCase")) => LitStr::new( - &heck::CamelCase::to_camel_case(ident.to_string().as_str()), - ident.span(), - ), - (None, Some("camelCase")) => LitStr::new( - &heck::MixedCase::to_mixed_case(ident.to_string().as_str()), - ident.span(), - ), - (None, Some("SCREAMING_SNAKE_CASE")) => LitStr::new( - &heck::ShoutySnakeCase::to_shouty_snake_case( - ident.to_string().as_str(), - ), - ident.span(), - ), - (None, Some("kebab-case")) => LitStr::new( - &heck::KebabCase::to_kebab_case(ident.to_string().as_str()), - ident.span(), - ), - (None, Some("SCREAMING-KEBAB-CASE")) => LitStr::new( - &heck::ShoutyKebabCase::to_shouty_kebab_case( - ident.to_string().as_str(), - ), - ident.span(), - ), - _ => LitStr::new(&ident.to_string(), ident.span()), - }, - )); + ident.span(), + ), + _ => LitStr::new(&ident.to_string(), ident.span()), + })); } } } - Fields::Unnamed(f) => { - if f.unnamed.len() == 1 { - // newtype wrapper - let field = &f.unnamed[0]; + Fields::Unnamed(f) if f.unnamed.len() == 1 => { + // newtype wrapper + let field = &f.unnamed[0]; + let ty = &field.ty; + let inner_model: Type = if let Some(child_model_name) = field + .attrs + .iter() + .filter(|attr| attr.path.is_ident("model")) + .map(|attr| attr.parse_args::().unwrap()) + .filter(|nv| nv.path.is_ident("name")) + .find_map(|nv| match nv.lit { + Lit::Str(s) => Some(s), + _ => None, + }) { + let child_model_ty = Ident::new(&child_model_name.value(), child_model_name.span()); + syn::parse2(quote! { #child_model_ty }).unwrap() + } else if field.attrs.iter().any(|attr| attr.path.is_ident("model")) { + syn::parse2(quote! { <#ty as patch_db::HasModel>::Model }).unwrap() + } else { + syn::parse2(quote! { patch_db::Model::<#ty> }).unwrap() + }; + let result = quote! { + #[derive(Debug)] + #model_vis struct #model_name(#inner_model); + impl #model_name { + pub fn into_model(self) -> patch_db::Model<#base_name> { + self.into() + } + } + impl std::clone::Clone for #model_name { + fn clone(&self) -> Self { + #model_name(self.0.clone()) + } + } + impl core::ops::Deref for #model_name { + type Target = #inner_model; + fn deref(&self) -> &Self::Target { + &self.0 + } + } + impl core::ops::DerefMut for #model_name { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + impl From for #model_name { + fn from(ptr: patch_db::json_ptr::JsonPointer) -> Self { + #model_name(#inner_model::from(ptr)) + } + } + impl From for #model_name { + fn from(ptr: patch_db::JsonGlob) -> Self { + #model_name(#inner_model::from(ptr)) + } + } + impl From<#model_name> for patch_db::Model<#base_name> { + fn from(model: #model_name) -> Self { + patch_db::Model::from(patch_db::JsonGlob::from(model)) + } + } + impl From<#model_name> for patch_db::json_ptr::JsonPointer { + fn from(model: #model_name) -> Self { + model.0.into() + } + } + impl From<#model_name> for patch_db::JsonGlob { + fn from(model: #model_name) -> Self { + model.0.into() + } + } + impl AsRef for #model_name { + fn as_ref(&self) -> &patch_db::json_ptr::JsonPointer { + self.0.as_ref() + } + } + impl From> for #model_name { + fn from(model: patch_db::Model<#base_name>) -> Self { + #model_name(#inner_model::from(patch_db::json_ptr::JsonPointer::from(model))) + } + } + impl From<#inner_model> for #model_name { + fn from(model: #inner_model) -> Self { + #model_name(model) + } + } + impl patch_db::HasModel for #base_name { + type Model = #model_name; + } + }; + // panic!("{}", result); + return result; + } + Fields::Unnamed(f) if f.unnamed.len() > 1 => { + for (i, field) in f.unnamed.iter().enumerate() { + child_fn_name.push(Ident::new( + &format!("idx_{}", i), + proc_macro2::Span::call_site(), + )); let ty = &field.ty; - let inner_model: Type = if let Some(child_model_name) = field + if let Some(child_model_name) = field .attrs .iter() .filter(|attr| attr.path.is_ident("model")) - .filter_map(|attr| Some(attr.parse_args::().unwrap())) + .map(|attr| attr.parse_args::().unwrap()) .filter(|nv| nv.path.is_ident("name")) .find_map(|nv| match nv.lit { Lit::Str(s) => Some(s), _ => None, - }) { + }) + { let child_model_ty = Ident::new(&child_model_name.value(), child_model_name.span()); - syn::parse2(quote! { #child_model_ty }).unwrap() + child_model + .push(syn::parse2(quote! { #child_model_ty }).expect("invalid model name")); } else if field.attrs.iter().any(|attr| attr.path.is_ident("model")) { - syn::parse2(quote! { <#ty as patch_db::HasModel>::Model }).unwrap() + child_model + .push(syn::parse2(quote! { <#ty as patch_db::HasModel>::Model }).unwrap()); } else { - syn::parse2(quote! { patch_db::Model::<#ty> }).unwrap() - }; - return quote! { - #[derive(Debug)] - #model_vis struct #model_name(#inner_model); - impl std::clone::Clone for #model_name { - fn clone(&self) -> Self { - #model_name(self.0.clone()) - } - } - impl core::ops::Deref for #model_name { - type Target = #inner_model; - fn deref(&self) -> &Self::Target { - &self.0 - } - } - impl core::ops::DerefMut for #model_name { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } - } - impl From for #model_name { - fn from(ptr: patch_db::json_ptr::JsonPointer) -> Self { - #model_name(#inner_model::from(ptr)) - } - } - impl From<#model_name> for patch_db::json_ptr::JsonPointer { - fn from(model: #model_name) -> Self { - model.0.into() - } - } - impl AsRef for #model_name { - fn as_ref(&self) -> &patch_db::json_ptr::JsonPointer { - self.0.as_ref() - } - } - impl From> for #model_name { - fn from(model: patch_db::Model<#base_name>) -> Self { - #model_name(#inner_model::from(patch_db::json_ptr::JsonPointer::from(model))) - } - } - impl From<#inner_model> for #model_name { - fn from(model: #inner_model) -> Self { - #model_name(model) - } - } - impl patch_db::HasModel for #base_name { - type Model = #model_name; - } - }; - } else if f.unnamed.len() > 1 { - for (i, field) in f.unnamed.iter().enumerate() { - child_fn_name.push(Ident::new( - &format!("idx_{}", i), - proc_macro2::Span::call_site(), - )); - let ty = &field.ty; - if let Some(child_model_name) = field - .attrs - .iter() - .filter(|attr| attr.path.is_ident("model")) - .filter_map(|attr| Some(attr.parse_args::().unwrap())) - .filter(|nv| nv.path.is_ident("name")) - .find_map(|nv| match nv.lit { - Lit::Str(s) => Some(s), - _ => None, - }) - { - let child_model_ty = - Ident::new(&child_model_name.value(), child_model_name.span()); - child_model.push( - syn::parse2(quote! { #child_model_ty }).expect("invalid model name"), - ); - } else if field.attrs.iter().any(|attr| attr.path.is_ident("model")) { - child_model.push( - syn::parse2(quote! { <#ty as patch_db::HasModel>::Model }).unwrap(), - ); - } else { - child_model.push(syn::parse2(quote! { patch_db::Model<#ty> }).unwrap()); - } - // TODO: serde rename for tuple structs? - // TODO: serde flatten for tuple structs? - child_path.push(Some(LitStr::new( - &format!("{}", i), - proc_macro2::Span::call_site(), - ))); + child_model.push(syn::parse2(quote! { patch_db::Model<#ty> }).unwrap()); } + // TODO: serde rename for tuple structs? + // TODO: serde flatten for tuple structs? + child_path.push(Some(LitStr::new( + &format!("{}", i), + proc_macro2::Span::call_site(), + ))); } } + Fields::Unnamed(_f) => {} Fields::Unit => (), } let child_path_expr = child_path.iter().map(|child_path| { @@ -292,17 +309,35 @@ fn build_model_struct( #child_path_expr } )* + pub fn into_model(self) -> patch_db::Model<#base_name> { + self.into() + } } impl From for #model_name { fn from(ptr: patch_db::json_ptr::JsonPointer) -> Self { #model_name(From::from(ptr)) } } + impl From for #model_name { + fn from(ptr: patch_db::JsonGlob) -> Self { + #model_name(From::from(ptr)) + } + } + impl From<#model_name> for patch_db::Model<#base_name> { + fn from(model: #model_name) -> Self { + model.0 + } + } impl From<#model_name> for patch_db::json_ptr::JsonPointer { fn from(model: #model_name) -> Self { model.0.into() } } + impl From<#model_name> for patch_db::JsonGlob { + fn from(model: #model_name) -> Self { + model.0.into() + } + } impl AsRef for #model_name { fn as_ref(&self) -> &patch_db::json_ptr::JsonPointer { self.0.as_ref() @@ -347,6 +382,16 @@ fn build_model_enum(base: &DeriveInput, _: &DataEnum, model_name: Option) #model_name(From::from(ptr)) } } + impl From for #model_name { + fn from(ptr: patch_db::JsonGlob) -> Self { + #model_name(From::from(ptr)) + } + } + impl From<#model_name> for patch_db::JsonGlob { + fn from(model: #model_name) -> Self { + model.0.into() + } + } impl From<#model_name> for patch_db::json_ptr::JsonPointer { fn from(model: #model_name) -> Self { model.0.into() diff --git a/patch-db/proptest-regressions/locker/proptest.txt b/patch-db/proptest-regressions/locker/proptest.txt new file mode 100644 index 0000000..39df059 --- /dev/null +++ b/patch-db/proptest-regressions/locker/proptest.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc b68e3527ee511046c2e96c9b75163422af755c5204f4a3cfaa973b8293bbf961 # shrinks to lock_order = [LockInfo { handle_id: HandleId { id: 0 }, ptr: PossibleStarPath { path: [Star], count: 1 }, ty: Exist }] diff --git a/patch-db/proptest-regressions/model_paths.txt b/patch-db/proptest-regressions/model_paths.txt new file mode 100644 index 0000000..0d9b334 --- /dev/null +++ b/patch-db/proptest-regressions/model_paths.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 32f96b89f02e5b9dfcc4de5a76c05a2a5e9ef1b232ea9cb5953a29169bec3b16 # shrinks to left = "", right = "" diff --git a/patch-db/src/bulk_locks.rs b/patch-db/src/bulk_locks.rs new file mode 100644 index 0000000..7362592 --- /dev/null +++ b/patch-db/src/bulk_locks.rs @@ -0,0 +1,167 @@ +use std::marker::PhantomData; + +use imbl::OrdSet; +use serde::{Deserialize, Serialize}; + +use crate::{model_paths::JsonGlob, DbHandle, Error, LockType}; + +use self::unsaturated_args::UnsaturatedArgs; + +pub mod unsaturated_args; + +/// Used at the beggining of a set of code that may acquire locks into a db. +/// This will be used to represent a potential lock that would be used, and this will then be +/// sent to a bulk locker, that will take multiple of these targets and lock them all at once instead +/// of one at a time. Then once the locks have been acquired, this target can then be turned into a receipt +/// which can then access into the db. +#[derive(Clone)] +pub struct LockTarget +where + T: Serialize + for<'de> Deserialize<'de>, +{ + pub glob: JsonGlob, + pub lock_type: LockType, + /// What the target will eventually need to return in a get, or value to be put in a set + pub(crate) db_type: PhantomData, + /// How many stars (potential keys in maps, ...) that need to be bound to actual paths. + pub(crate) _star_binds: UnsaturatedArgs, +} + +/// This is acting as a newtype for the copyable section +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] +pub struct LockTargetId { + pub(crate) glob: JsonGlob, + pub(crate) lock_type: LockType, +} + +impl LockTarget +where + T: Serialize + for<'de> Deserialize<'de>, +{ + pub fn key_for_indexing(&self) -> LockTargetId { + let paths: &JsonGlob = &self.glob; + LockTargetId { + // TODO: Remove this clone + glob: paths.clone(), + lock_type: self.lock_type, + } + } + + pub fn add_to_keys(self, locks: &mut Vec) -> Self { + locks.push(self.key_for_indexing()); + self + } +} + +#[derive(Debug, Clone)] +pub struct Verifier { + pub(crate) target_locks: OrdSet, +} + +impl LockTarget +where + T: Serialize + for<'de> Deserialize<'de>, +{ + /// Use this to verify the target, and if valid return a verified lock + pub fn verify(self, lock_set: &Verifier) -> Result, Error> { + if !lock_set.target_locks.contains(&self.key_for_indexing()) { + return Err(Error::Locker( + "Cannot unlock a lock that is not in the unlock set".to_string(), + )); + } + Ok(LockReceipt { lock: self }) + } +} + +/// A lock reciept is the final goal, where we can now get/ set into the db +#[derive(Clone)] +pub struct LockReceipt +where + T: Serialize + for<'de> Deserialize<'de>, +{ + pub lock: LockTarget, +} + +impl LockReceipt +where + T: Serialize + for<'de> Deserialize<'de> + Send + Sync, +{ + async fn set_( + &self, + db_handle: &mut DH, + new_value: T, + binds: &[&str], + ) -> Result<(), Error> { + let lock_type = self.lock.lock_type; + let pointer = &self.lock.glob.as_pointer(binds); + if lock_type != LockType::Write { + return Err(Error::Locker("Cannot set a read lock".to_string())); + } + + db_handle.put(pointer, &new_value).await?; + Ok(()) + } + async fn get_( + &self, + db_handle: &mut DH, + binds: &[&str], + ) -> Result, Error> { + let path = self.lock.glob.as_pointer(binds); + if !db_handle.exists(&path, None).await? { + return Ok(None); + } + Ok(Some(db_handle.get(&path).await?)) + } +} +impl LockReceipt +where + T: Serialize + for<'de> Deserialize<'de> + Send + Sync, +{ + pub async fn set(&self, db_handle: &mut DH, new_value: T) -> Result<(), Error> { + self.set_(db_handle, new_value, &[]).await + } + pub async fn get(&self, db_handle: &mut DH) -> Result { + self.get_(db_handle, &[]).await.map(|x| x.unwrap()) + } +} +impl LockReceipt +where + T: Serialize + for<'de> Deserialize<'de> + Send + Sync, +{ + pub async fn set( + &self, + db_handle: &mut DH, + new_value: T, + binds: &str, + ) -> Result<(), Error> { + self.set_(db_handle, new_value, &[binds]).await + } + pub async fn get( + &self, + db_handle: &mut DH, + binds: &str, + ) -> Result, Error> { + self.get_(db_handle, &[binds]).await + } +} + +impl LockReceipt +where + T: Serialize + for<'de> Deserialize<'de> + Send + Sync, +{ + pub async fn set( + &self, + db_handle: &mut DH, + new_value: T, + binds: (&str, &str), + ) -> Result<(), Error> { + self.set_(db_handle, new_value, &[binds.0, binds.1]).await + } + pub async fn get( + &self, + db_handle: &mut DH, + binds: (&str, &str), + ) -> Result, Error> { + self.get_(db_handle, &[binds.0, binds.1]).await + } +} diff --git a/patch-db/src/bulk_locks/unsaturated_args.rs b/patch-db/src/bulk_locks/unsaturated_args.rs new file mode 100644 index 0000000..4a357c2 --- /dev/null +++ b/patch-db/src/bulk_locks/unsaturated_args.rs @@ -0,0 +1,70 @@ +use std::marker::PhantomData; + +use crate::JsonGlob; + +/// Used to create a proof that will be consumed later to verify the amount of arguments needed to get a path. +/// One of the places that it is used is when creating a lock target +#[derive(Clone, Debug, Copy)] +pub struct UnsaturatedArgs(PhantomData); + +pub trait AsUnsaturatedArgs { + fn as_unsaturated_args(&self) -> UnsaturatedArgs; +} + +impl AsUnsaturatedArgs<()> for JsonGlob { + fn as_unsaturated_args(&self) -> UnsaturatedArgs<()> { + let count = match self { + JsonGlob::PathWithStar(path_with_star) => path_with_star.count(), + JsonGlob::Path(_) => 0, + }; + if count != 0 { + #[cfg(feature = "tracing")] + tracing::error!("By counts={}, this phantom type = () is not valid", count); + #[cfg(test)] + panic!("By counts={}, this phantom type = () is not valid", count); + } + UnsaturatedArgs(PhantomData) + } +} +impl AsUnsaturatedArgs for JsonGlob { + fn as_unsaturated_args(&self) -> UnsaturatedArgs { + let count = match self { + JsonGlob::PathWithStar(path_with_star) => path_with_star.count(), + JsonGlob::Path(_) => 0, + }; + if count != 1 { + #[cfg(feature = "tracing")] + tracing::error!( + "By counts={}, this phantom type = String is not valid", + count + ); + #[cfg(test)] + panic!( + "By counts={}, this phantom type = String is not valid", + count + ); + } + UnsaturatedArgs(PhantomData) + } +} +impl AsUnsaturatedArgs<(String, String)> for JsonGlob { + fn as_unsaturated_args(&self) -> UnsaturatedArgs<(String, String)> { + let count = match self { + JsonGlob::PathWithStar(path_with_star) => path_with_star.count(), + JsonGlob::Path(_) => 0, + }; + if count != 2 { + #[cfg(feature = "tracing")] + tracing::error!( + "By counts={}, this phantom type = (String, String) is not valid", + count + ); + #[cfg(test)] + panic!( + "By counts={}, this phantom type = (String, String) is not valid", + count + ); + } + UnsaturatedArgs(PhantomData) + } +} diff --git a/patch-db/src/handle.rs b/patch-db/src/handle.rs index 9190cb0..92718cd 100644 --- a/patch-db/src/handle.rs +++ b/patch-db/src/handle.rs @@ -8,8 +8,11 @@ use serde_json::Value; use tokio::sync::broadcast::Receiver; use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; -use crate::locker::{Guard, LockType}; -use crate::patch::DiffPatch; +use crate::{ + bulk_locks::{self, Verifier}, + locker::{Guard, LockType}, +}; +use crate::{model_paths::JsonGlob, patch::DiffPatch}; use crate::{Error, Locker, PatchDb, Revision, Store, Transaction}; #[derive(Debug, Clone, Default)] @@ -34,10 +37,13 @@ impl Ord for HandleId { self.id.cmp(&other.id) } } - #[async_trait] -pub trait DbHandle: Send + Sync { +pub trait DbHandle: Send + Sync + Sized { async fn begin<'a>(&'a mut self) -> Result, Error>; + async fn lock_all<'a>( + &'a mut self, + locks: impl IntoIterator + Send + Sync + Clone + 'a, + ) -> Result; fn id(&self) -> HandleId; fn rebase(&mut self) -> Result<(), Error>; fn store(&self) -> Arc>; @@ -68,7 +74,7 @@ pub trait DbHandle: Send + Sync { patch: DiffPatch, store_write_lock: Option>, ) -> Result>, Error>; - async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) -> Result<(), Error>; + async fn lock(&mut self, ptr: JsonGlob, lock_type: LockType) -> Result<(), Error>; async fn get< T: for<'de> Deserialize<'de>, S: AsRef + Send + Sync, @@ -154,7 +160,7 @@ impl DbHandle for &mut Handle { ) -> Result>, Error> { (*self).apply(patch, store_write_lock).await } - async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) -> Result<(), Error> { + async fn lock(&mut self, ptr: JsonGlob, lock_type: LockType) -> Result<(), Error> { (*self).lock(ptr, lock_type).await } async fn get< @@ -178,6 +184,13 @@ impl DbHandle for &mut Handle { ) -> Result>, Error> { (*self).put(ptr, value).await } + + async fn lock_all<'a>( + &'a mut self, + locks: impl IntoIterator + Send + Sync + Clone + 'a, + ) -> Result { + (*self).lock_all(locks).await + } } pub struct PatchDbHandle { @@ -266,10 +279,10 @@ impl DbHandle for PatchDbHandle { ) -> Result>, Error> { self.db.apply(patch, None, store_write_lock).await } - async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) -> Result<(), Error> { - Ok(self - .locks - .push(self.db.locker.lock(self.id.clone(), ptr, lock_type).await?)) + async fn lock(&mut self, ptr: JsonGlob, lock_type: LockType) -> Result<(), Error> { + self.locks + .push(self.db.locker.lock(self.id.clone(), ptr, lock_type).await?); + Ok(()) } async fn get< T: for<'de> Deserialize<'de>, @@ -292,4 +305,118 @@ impl DbHandle for PatchDbHandle { ) -> Result>, Error> { self.db.put(ptr, value, None).await } + + async fn lock_all<'a>( + &'a mut self, + locks: impl IntoIterator + Send + Sync + Clone + 'a, + ) -> Result { + let verifier = Verifier { + target_locks: locks.clone().into_iter().collect(), + }; + let guard = self.db.locker.lock_all(&self.id, locks).await?; + + self.locks.push(guard); + Ok(verifier) + } +} + +pub mod test_utils { + use async_trait::async_trait; + + use crate::{Error, Locker, Revision, Store, Transaction}; + + use super::*; + + pub struct NoOpDb(); + + #[async_trait] + impl DbHandle for NoOpDb { + async fn begin<'a>(&'a mut self) -> Result, Error> { + unimplemented!() + } + fn id(&self) -> HandleId { + unimplemented!() + } + fn rebase(&mut self) -> Result<(), Error> { + unimplemented!() + } + fn store(&self) -> Arc> { + unimplemented!() + } + fn subscribe(&self) -> Receiver> { + unimplemented!() + } + fn locker(&self) -> &Locker { + unimplemented!() + } + async fn exists + Send + Sync, V: SegList + Send + Sync>( + &mut self, + _ptr: &JsonPointer, + _store_read_lock: Option>, + ) -> Result { + unimplemented!() + } + async fn keys + Send + Sync, V: SegList + Send + Sync>( + &mut self, + _ptr: &JsonPointer, + _store_read_lock: Option>, + ) -> Result, Error> { + unimplemented!() + } + async fn get_value + Send + Sync, V: SegList + Send + Sync>( + &mut self, + _ptr: &JsonPointer, + _store_read_lock: Option>, + ) -> Result { + unimplemented!() + } + async fn put_value + Send + Sync, V: SegList + Send + Sync>( + &mut self, + _ptr: &JsonPointer, + _value: &Value, + ) -> Result>, Error> { + unimplemented!() + } + async fn apply( + &mut self, + _patch: DiffPatch, + _store_write_lock: Option>, + ) -> Result>, Error> { + unimplemented!() + } + async fn lock(&mut self, _ptr: JsonGlob, _lock_type: LockType) -> Result<(), Error> { + unimplemented!() + } + async fn get< + T: for<'de> Deserialize<'de>, + S: AsRef + Send + Sync, + V: SegList + Send + Sync, + >( + &mut self, + _ptr: &JsonPointer, + ) -> Result { + unimplemented!() + } + async fn put< + T: Serialize + Send + Sync, + S: AsRef + Send + Sync, + V: SegList + Send + Sync, + >( + &mut self, + _ptr: &JsonPointer, + _value: &T, + ) -> Result>, Error> { + unimplemented!() + } + + async fn lock_all<'a>( + &'a mut self, + locks: impl IntoIterator + Send + Sync + Clone + 'a, + ) -> Result { + let skeleton_key = Verifier { + target_locks: locks.into_iter().collect(), + }; + Ok(skeleton_key) + } + } } diff --git a/patch-db/src/lib.rs b/patch-db/src/lib.rs index 1f98bc9..a71e6e8 100644 --- a/patch-db/src/lib.rs +++ b/patch-db/src/lib.rs @@ -8,9 +8,11 @@ use tokio::sync::broadcast::error::TryRecvError; // note: inserting into an array (before another element) without proper locking can result in unexpected behaviour +mod bulk_locks; mod handle; mod locker; mod model; +mod model_paths; mod patch; mod store; mod transaction; @@ -23,12 +25,20 @@ pub use locker::{LockType, Locker}; pub use model::{ BoxModel, HasModel, Map, MapModel, Model, ModelData, ModelDataMut, OptionModel, VecModel, }; +pub use model_paths::{JsonGlob, JsonGlobSegment}; pub use patch::{DiffPatch, Dump, Revision}; pub use patch_db_macro::HasModel; pub use store::{PatchDb, Store}; pub use transaction::Transaction; pub use {json_patch, json_ptr}; +pub use bulk_locks::{LockReceipt, LockTarget, LockTargetId, Verifier}; + +pub mod test_utils { + use super::*; + pub use handle::test_utils::*; +} + #[derive(Error, Debug)] pub enum Error { #[error("IO Error: {0}")] @@ -53,4 +63,6 @@ pub enum Error { NodeDoesNotExist(JsonPointer), #[error("Invalid Lock Request: {0}")] LockError(#[from] LockError), + #[error("Invalid Lock Request: {0}")] + Locker(String), } diff --git a/patch-db/src/locker/action_mux.rs b/patch-db/src/locker/action_mux.rs index b40baeb..1095be3 100644 --- a/patch-db/src/locker/action_mux.rs +++ b/patch-db/src/locker/action_mux.rs @@ -2,13 +2,13 @@ use tokio::sync::mpsc::{self, UnboundedReceiver}; use tokio::sync::oneshot; use tokio::sync::oneshot::error::TryRecvError; -use super::{LockInfo, Request}; +use super::{LockInfos, Request}; #[derive(Debug)] pub(super) enum Action { HandleRequest(Request), - HandleRelease(LockInfo), - HandleCancel(LockInfo), + HandleRelease(LockInfos), + HandleCancel(LockInfos), } struct InboundRequestQueue { @@ -17,9 +17,9 @@ struct InboundRequestQueue { } pub(super) struct ActionMux { inbound_request_queue: InboundRequestQueue, - unlock_receivers: Vec>, - cancellation_receivers: Vec>, - _dummy_senders: Vec>, + unlock_receivers: Vec>, + cancellation_receivers: Vec>, + _dummy_senders: Vec>, } impl ActionMux { pub fn new(inbound_receiver: UnboundedReceiver) -> Self { @@ -106,14 +106,14 @@ impl ActionMux { } } - pub fn push_unlock_receivers>>( + pub fn push_unlock_receivers>>( &mut self, recv: T, ) { self.unlock_receivers.extend(recv) } - pub fn push_cancellation_receiver(&mut self, recv: oneshot::Receiver) { + pub fn push_cancellation_receiver(&mut self, recv: oneshot::Receiver) { self.cancellation_receivers.push(recv) } } diff --git a/patch-db/src/locker/bookkeeper.rs b/patch-db/src/locker/bookkeeper.rs index 478ef0b..055c374 100644 --- a/patch-db/src/locker/bookkeeper.rs +++ b/patch-db/src/locker/bookkeeper.rs @@ -5,9 +5,10 @@ use tokio::sync::oneshot; #[cfg(feature = "tracing")] use tracing::{debug, error, info, warn}; +#[cfg(feature = "unstable")] use super::order_enforcer::LockOrderEnforcer; use super::trie::LockTrie; -use super::{LockError, LockInfo, Request}; +use super::{LockError, LockInfos, Request}; use crate::handle::HandleId; #[cfg(feature = "tracing")] use crate::locker::log_utils::{ @@ -35,7 +36,7 @@ impl LockBookkeeper { pub fn lease( &mut self, req: Request, - ) -> Result>, LockError> { + ) -> Result>, LockError> { #[cfg(feature = "unstable")] if let Err(e) = self.order_enforcer.try_insert(&req.lock_info) { req.reject(e.clone()); @@ -53,45 +54,69 @@ impl LockBookkeeper { if let Some(hot_seat) = hot_seat { self.deferred_request_queue.push_front(hot_seat); - kill_deadlocked(&mut self.deferred_request_queue, &mut self.trie); + kill_deadlocked(&mut self.deferred_request_queue, &self.trie); } Ok(res) } - pub fn cancel(&mut self, info: &LockInfo) { + pub fn cancel(&mut self, info: &LockInfos) { #[cfg(feature = "unstable")] - self.order_enforcer.remove(&info); + for info in info.as_vec() { + self.order_enforcer.remove(&info); + } let entry = self .deferred_request_queue .iter() .enumerate() .find(|(_, (r, _))| &r.lock_info == info); - match entry { + let index = match entry { None => { #[cfg(feature = "tracing")] - warn!( - "Received cancellation for a lock not currently waiting: {}", - info.ptr - ); + { + let infos = &info.0; + warn!( + "Received cancellation for some locks not currently waiting: [{}]", + infos + .iter() + .enumerate() + .fold(String::new(), |acc, (i, new)| { + if i > 0 { + format!("{}/{}", acc, new.ptr) + } else { + format!("/{}", new.ptr) + } + }) + ); + } + return; } - Some((i, (req, _))) => { + Some(value) => { #[cfg(feature = "tracing")] - info!("{}", fmt_cancelled(&req.lock_info)); - - self.deferred_request_queue.remove(i); + for lock_info in value.1 .0.lock_info.as_vec() { + info!("{}", fmt_cancelled(lock_info)); + } + value.0 } - } + }; + + self.deferred_request_queue.remove(index); } - pub fn ret(&mut self, info: &LockInfo) -> Vec> { + pub fn ret(&mut self, info: &LockInfos) -> Vec> { #[cfg(feature = "unstable")] - self.order_enforcer.remove(&info); - self.trie.unlock(&info); + for info in info.as_vec() { + self.order_enforcer.remove(&info); + } + for info in info.as_vec() { + self.trie.unlock(info); + } #[cfg(feature = "tracing")] { - info!("{}", fmt_released(&info)); + for info in info.as_vec() { + info!("{}", fmt_released(&info)); + } debug!("Reexamining request queue backlog..."); } @@ -127,7 +152,7 @@ impl LockBookkeeper { } if let Some(hot_seat) = hot_seat { self.deferred_request_queue.push_front(hot_seat); - kill_deadlocked(&mut self.deferred_request_queue, &mut self.trie); + kill_deadlocked(&mut self.deferred_request_queue, &self.trie); } new_unlock_receivers } @@ -141,21 +166,31 @@ fn process_new_req( hot_seat: Option<&(Request, OrdSet)>, trie: &mut LockTrie, request_queue: &mut VecDeque<(Request, OrdSet)>, -) -> Option> { +) -> Option> { + #[cfg(feature = "tracing")] + let lock_infos = req.lock_info.as_vec(); match hot_seat { // hot seat conflicts and request session isn't in current blocking sessions // so we push it to the queue Some((hot_req, hot_blockers)) if hot_req.lock_info.conflicts_with(&req.lock_info) - && !hot_blockers.contains(&req.lock_info.handle_id) => + && !req + .lock_info + .as_vec() + .iter() + .any(|lock_info| hot_blockers.contains(&lock_info.handle_id)) => { #[cfg(feature = "tracing")] { - info!("{}", fmt_deferred(&req.lock_info)); - debug!( - "Must wait on hot seat request from session {}", - &hot_req.lock_info.handle_id.id - ); + for lock_info in lock_infos.iter() { + info!("{}", fmt_deferred(&lock_info)); + } + if let Some(hot_req_lock_info) = hot_req.lock_info.as_vec().first() { + debug!( + "Must wait on hot seat request from session {}", + &hot_req_lock_info.handle_id.id + ); + } } request_queue.push_back((req, ordset![])); @@ -165,14 +200,18 @@ fn process_new_req( _ => match trie.try_lock(&req.lock_info) { Ok(()) => { #[cfg(feature = "tracing")] - info!("{}", fmt_acquired(&req.lock_info)); + for lock_info in lock_infos.iter() { + info!("{}", fmt_acquired(&lock_info)); + } Some(req.complete()) } Err(blocking_sessions) => { #[cfg(feature = "tracing")] { - info!("{}", fmt_deferred(&req.lock_info)); + for lock_info in lock_infos.iter() { + info!("{}", fmt_deferred(&lock_info)); + } debug!( "Must wait on sessions {}", display_session_set(&blocking_sessions) @@ -200,7 +239,13 @@ fn kill_deadlocked(request_queue: &mut VecDeque<(Request, OrdSet)>, tr error!("Deadlock Detected: {:?}", locks_waiting); let err = LockError::DeadlockDetected { locks_waiting, - locks_held: LockSet(trie.subtree_lock_info()), + locks_held: LockSet( + trie.subtree_lock_info() + .into_iter() + .map(|x| vec![x]) + .map(LockInfos) + .collect(), + ), }; let mut indices_to_remove = Vec::with_capacity(deadlocked_reqs.len()); @@ -220,15 +265,23 @@ fn kill_deadlocked(request_queue: &mut VecDeque<(Request, OrdSet)>, tr } } -pub(super) fn deadlock_scan<'a>( - queue: &'a VecDeque<(Request, OrdSet)>, -) -> Vec<&'a Request> { +pub(super) fn deadlock_scan(queue: &VecDeque<(Request, OrdSet)>) -> Vec<&'_ Request> { let (wait_map, mut req_map) = queue .iter() - .map(|(req, set)| ((&req.lock_info.handle_id, set, req))) + .flat_map(|(req, set)| { + req.lock_info + .as_vec() + .into_iter() + .map(|lock_info| (&lock_info.handle_id, set, req)) + .collect::>() + }) .fold( (ordmap! {}, ordmap! {}), - |(mut wmap, mut rmap), (id, wset, req)| { + |(mut wmap, mut rmap): ( + OrdMap<&HandleId, &OrdSet>, + OrdMap<&HandleId, &Request>, + ), + (id, wset, req)| { ( { wmap.insert(id, wset); diff --git a/patch-db/src/locker/mod.rs b/patch-db/src/locker/mod.rs index edb4cdc..040eba5 100644 --- a/patch-db/src/locker/mod.rs +++ b/patch-db/src/locker/mod.rs @@ -1,5 +1,6 @@ mod action_mux; mod bookkeeper; +#[cfg(feature = "tracing")] mod log_utils; mod natural; mod order_enforcer; @@ -8,15 +9,14 @@ pub(crate) mod proptest; mod trie; use imbl::{ordmap, ordset, OrdMap, OrdSet}; -use json_ptr::JsonPointer; use tokio::sync::{mpsc, oneshot}; #[cfg(feature = "tracing")] use tracing::{debug, trace, warn}; use self::action_mux::ActionMux; use self::bookkeeper::LockBookkeeper; -use crate::handle::HandleId; -use crate::locker::action_mux::Action; +use crate::{bulk_locks::LockTargetId, locker::action_mux::Action, Verifier}; +use crate::{handle::HandleId, JsonGlob}; pub struct Locker { sender: mpsc::UnboundedSender, @@ -74,15 +74,44 @@ impl Locker { pub async fn lock( &self, handle_id: HandleId, - ptr: JsonPointer, + ptr: JsonGlob, lock_type: LockType, ) -> Result { // Pertinent Logic - let lock_info = LockInfo { + let lock_info: LockInfos = LockInfo { handle_id, ptr, ty: lock_type, - }; + } + .into(); + self._lock(lock_info).await + } + + pub async fn lock_all( + &self, + handle_id: &HandleId, + locks: impl IntoIterator + Send, + ) -> Result { + let lock_infos = LockInfos( + locks + .into_iter() + .map( + |LockTargetId { + glob: ptr, + lock_type: ty, + }| { + LockInfo { + handle_id: handle_id.clone(), + ptr, + ty, + } + }, + ) + .collect(), + ); + self._lock(lock_infos).await + } + async fn _lock(&self, lock_info: LockInfos) -> Result { let (send, recv) = oneshot::channel(); let (cancel_send, cancel_recv) = oneshot::channel(); let mut cancel_guard = CancelGuard { @@ -101,11 +130,11 @@ impl Locker { cancel_guard.channel.take(); res } -} // Local Definitions +} #[derive(Debug)] struct CancelGuard { - lock_info: Option, - channel: Option>, + lock_info: Option, + channel: Option>, recv: oneshot::Receiver>, } impl Drop for CancelGuard { @@ -117,10 +146,37 @@ impl Drop for CancelGuard { } } +#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)] +struct LockInfos(pub Vec); +impl LockInfos { + fn conflicts_with(&self, other: &LockInfos) -> bool { + let other_lock_infos = &other.0; + self.0.iter().any(|lock_info| { + other_lock_infos + .iter() + .any(|other_lock_info| lock_info.conflicts_with(other_lock_info)) + }) + } + + fn as_vec(&self) -> &Vec { + &self.0 + } +} + +impl std::fmt::Display for LockInfos { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let lock_infos = &self.0; + for lock_info in lock_infos { + write!(f, "{},", lock_info)?; + } + Ok(()) + } +} + #[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord)] struct LockInfo { handle_id: HandleId, - ptr: JsonPointer, + ptr: JsonGlob, ty: LockType, } impl LockInfo { @@ -144,6 +200,7 @@ impl LockInfo { } } } + #[cfg(any(feature = "unstable", test))] fn implicitly_grants(&self, other: &LockInfo) -> bool { self.handle_id == other.handle_id && match self.ty { @@ -163,6 +220,12 @@ impl LockInfo { } } } + +impl From for LockInfos { + fn from(lock_info: LockInfo) -> Self { + LockInfos(vec![lock_info]) + } +} impl std::fmt::Display for LockInfo { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}{}{}", self.handle_id.id, self.ty, self.ptr) @@ -192,16 +255,17 @@ impl std::fmt::Display for LockType { } #[derive(Debug, Clone)] -pub struct LockSet(OrdSet); +pub struct LockSet(OrdSet); impl std::fmt::Display for LockSet { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let by_session = self .0 .iter() + .flat_map(|x| x.as_vec()) .map(|i| (&i.handle_id, ordset![(&i.ptr, &i.ty)])) .fold( ordmap! {}, - |m: OrdMap<&HandleId, OrdSet<(&JsonPointer, &LockType)>>, (id, s)| { + |m: OrdMap<&HandleId, OrdSet<(&JsonGlob, &LockType)>>, (id, s)| { m.update_with(&id, s, OrdSet::union) }, ); @@ -230,22 +294,22 @@ pub enum LockError { #[error("Lock Taxonomy Escalation: Session = {session:?}, First = {first}, Second = {second}")] LockTaxonomyEscalation { session: HandleId, - first: JsonPointer, - second: JsonPointer, + first: JsonGlob, + second: JsonGlob, }, #[error("Lock Type Escalation: Session = {session:?}, Pointer = {ptr}, First = {first}, Second = {second}")] LockTypeEscalation { session: HandleId, - ptr: JsonPointer, + ptr: JsonGlob, first: LockType, second: LockType, }, #[error("Lock Type Escalation Implicit: Session = {session:?}, First = {first_ptr}:{first_type}, Second = {second_ptr}:{second_type}")] LockTypeEscalationImplicit { session: HandleId, - first_ptr: JsonPointer, + first_ptr: JsonGlob, first_type: LockType, - second_ptr: JsonPointer, + second_ptr: JsonGlob, second_type: LockType, }, #[error( @@ -253,8 +317,8 @@ pub enum LockError { )] NonCanonicalOrdering { session: HandleId, - first: JsonPointer, - second: JsonPointer, + first: JsonGlob, + second: JsonGlob, }, #[error("Deadlock Detected:\nLocks Held =\n{locks_held},\nLocks Waiting =\n{locks_waiting}")] DeadlockDetected { @@ -265,12 +329,12 @@ pub enum LockError { #[derive(Debug)] struct Request { - lock_info: LockInfo, - cancel: Option>, + lock_info: LockInfos, + cancel: Option>, completion: oneshot::Sender>, } impl Request { - fn complete(self) -> oneshot::Receiver { + fn complete(self) -> oneshot::Receiver { let (sender, receiver) = oneshot::channel(); if let Err(_) = self.completion.send(Ok(Guard { lock_info: self.lock_info, @@ -291,8 +355,8 @@ impl Request { #[derive(Debug)] pub struct Guard { - lock_info: LockInfo, - sender: Option>, + lock_info: LockInfos, + sender: Option>, } impl Drop for Guard { fn drop(&mut self) { @@ -307,3 +371,100 @@ impl Drop for Guard { } } } + +#[test] +fn conflicts_with_locker_infos_cases() { + let mut id: u64 = 0; + let lock_info_a = LockInfo { + handle_id: HandleId { + id: { + id += 1; + id + }, + #[cfg(feature = "trace")] + trace: None, + }, + ty: LockType::Write, + ptr: "/a".parse().unwrap(), + }; + let lock_infos_a = LockInfos(vec![lock_info_a.clone()]); + let lock_info_b = LockInfo { + handle_id: HandleId { + id: { + id += 1; + id + }, + #[cfg(feature = "trace")] + trace: None, + }, + ty: LockType::Write, + ptr: "/b".parse().unwrap(), + }; + let lock_infos_b = LockInfos(vec![lock_info_b.clone()]); + let lock_info_a_s = LockInfo { + handle_id: HandleId { + id: { + id += 1; + id + }, + #[cfg(feature = "trace")] + trace: None, + }, + ty: LockType::Write, + ptr: "/a/*".parse().unwrap(), + }; + let lock_infos_a_s = LockInfos(vec![lock_info_a_s.clone()]); + let lock_info_a_s_c = LockInfo { + handle_id: HandleId { + id: { + id += 1; + id + }, + #[cfg(feature = "trace")] + trace: None, + }, + ty: LockType::Write, + ptr: "/a/*/c".parse().unwrap(), + }; + let lock_infos_a_s_c = LockInfos(vec![lock_info_a_s_c.clone()]); + + let lock_info_a_b_c = LockInfo { + handle_id: HandleId { + id: { + id += 1; + id + }, + #[cfg(feature = "trace")] + trace: None, + }, + ty: LockType::Write, + ptr: "/a/b/c".parse().unwrap(), + }; + let lock_infos_a_b_c = LockInfos(vec![lock_info_a_b_c.clone()]); + + let lock_infos_set = LockInfos(vec![lock_info_a.clone()]); + let lock_infos_set_b = LockInfos(vec![lock_info_b]); + let lock_infos_set_deep = LockInfos(vec![ + lock_info_a_s.clone(), + lock_info_a_s_c.clone(), + lock_info_a_b_c.clone(), + ]); + let lock_infos_set_all = LockInfos(vec![ + lock_info_a, + lock_info_a_s, + lock_info_a_s_c, + lock_info_a_b_c, + ]); + + assert!(!lock_infos_b.conflicts_with(&lock_infos_a)); + assert!(!lock_infos_a.conflicts_with(&lock_infos_a)); // same lock won't + assert!(lock_infos_a_s.conflicts_with(&lock_infos_a)); // Since the parent is locked, it won't be able to + assert!(lock_infos_a_s.conflicts_with(&lock_infos_a_s_c)); + assert!(lock_infos_a_s_c.conflicts_with(&lock_infos_a_b_c)); + assert!(!lock_infos_set.conflicts_with(&lock_infos_a)); // Same lock again + assert!(lock_infos_set.conflicts_with(&lock_infos_set_deep)); // Since this is a parent + assert!(!lock_infos_set_b.conflicts_with(&lock_infos_set_deep)); // Sets are exclusive + assert!(!lock_infos_set.conflicts_with(&lock_infos_set_b)); // Sets are exclusive + assert!(lock_infos_set_deep.conflicts_with(&lock_infos_set)); // Shared parent a + assert!(lock_infos_set_deep.conflicts_with(&lock_infos_set_all)); // Shared parent a +} diff --git a/patch-db/src/locker/order_enforcer.rs b/patch-db/src/locker/order_enforcer.rs index a1c2d0b..f878454 100644 --- a/patch-db/src/locker/order_enforcer.rs +++ b/patch-db/src/locker/order_enforcer.rs @@ -1,22 +1,27 @@ -use imbl::{ordmap, OrdMap}; -use json_ptr::JsonPointer; +use imbl::OrdMap; #[cfg(feature = "tracing")] use tracing::warn; -use super::{LockError, LockInfo}; -use crate::handle::HandleId; +use super::LockInfo; use crate::LockType; +use crate::{handle::HandleId, model_paths::JsonGlob}; + +#[cfg(any(feature = "unstable", test))] +use super::LockError; #[derive(Debug, PartialEq, Eq)] pub(super) struct LockOrderEnforcer { - locks_held: OrdMap>, + locks_held: OrdMap>, } impl LockOrderEnforcer { + #[cfg(any(feature = "unstable", test))] pub fn new() -> Self { LockOrderEnforcer { - locks_held: ordmap! {}, + locks_held: imbl::ordmap! {}, } } + #[cfg_attr(feature = "trace", tracing::instrument)] + #[cfg(any(feature = "unstable", test))] // locks must be acquired in lexicographic order for the pointer, and reverse order for type fn validate(&self, req: &LockInfo) -> Result<(), LockError> { // the following notation is used to denote an example sequence that can cause deadlocks @@ -94,29 +99,36 @@ impl LockOrderEnforcer { } } } - pub(super) fn try_insert(&mut self, req: &LockInfo) -> Result<(), LockError> { - self.validate(req)?; - match self.locks_held.get_mut(&req.handle_id) { - None => { - self.locks_held.insert( - req.handle_id.clone(), - ordmap![(req.ptr.clone(), req.ty) => 1], - ); - } - Some(locks) => { - let k = (req.ptr.clone(), req.ty); - match locks.get_mut(&k) { - None => { - locks.insert(k, 1); - } - Some(n) => { - *n += 1; + #[cfg(any(feature = "unstable", test))] + pub(super) fn try_insert(&mut self, reqs: &super::LockInfos) -> Result<(), LockError> { + // These are seperate since we want to check all first before we insert + for req in reqs.as_vec() { + self.validate(req)?; + } + for req in reqs.as_vec() { + match self.locks_held.get_mut(&req.handle_id) { + None => { + self.locks_held.insert( + req.handle_id.clone(), + imbl::ordmap![(req.ptr.clone(), req.ty) => 1], + ); + } + Some(locks) => { + let k = (req.ptr.clone(), req.ty); + match locks.get_mut(&k) { + None => { + locks.insert(k, 1); + } + Some(n) => { + *n += 1; + } } } } } Ok(()) } + #[cfg(any(feature = "unstable", test))] pub(super) fn remove(&mut self, req: &LockInfo) { match self.locks_held.remove_with_key(&req.handle_id) { None => { diff --git a/patch-db/src/locker/proptest.rs b/patch-db/src/locker/proptest.rs index 79ce88e..be799ad 100644 --- a/patch-db/src/locker/proptest.rs +++ b/patch-db/src/locker/proptest.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod tests { - use std::collections::{HashMap, VecDeque}; + use std::collections::VecDeque; use imbl::{ordmap, ordset, OrdMap, OrdSet}; use json_ptr::JsonPointer; @@ -10,9 +10,12 @@ mod tests { use tokio::sync::oneshot; use crate::handle::HandleId; - use crate::locker::bookkeeper::{deadlock_scan, path_to}; - use crate::locker::{CancelGuard, Guard, LockError, LockInfo, LockType, Request}; + use crate::locker::{CancelGuard, LockError, LockInfo, LockInfos, LockType, Request}; use crate::Locker; + use crate::{ + locker::bookkeeper::{deadlock_scan, path_to}, + JsonGlob, + }; // enum Action { // Acquire { @@ -66,11 +69,22 @@ mod tests { }) .boxed() } + fn arb_model_paths(max_size: usize) -> impl Strategy { + proptest::collection::vec("[a-z*]", 1..max_size).prop_map(|a_s| { + a_s.into_iter() + .fold(String::new(), |mut s, x| { + s.push_str(&x); + s + }) + .parse::() + .unwrap() + }) + } fn arb_lock_info(session_bound: u64, ptr_max_size: usize) -> BoxedStrategy { arb_handle_id(session_bound) .prop_flat_map(move |handle_id| { - arb_json_ptr(ptr_max_size).prop_flat_map(move |ptr| { + arb_model_paths(ptr_max_size).prop_flat_map(move |ptr| { let handle_id = handle_id.clone(); arb_lock_type().prop_map(move |ty| LockInfo { handle_id: handle_id.clone(), @@ -81,19 +95,39 @@ mod tests { }) .boxed() } + fn arb_lock_infos( + session_bound: u64, + ptr_max_size: usize, + max_size: usize, + ) -> BoxedStrategy { + arb_handle_id(session_bound) + .prop_flat_map(move |handle_id| { + proptest::collection::vec(arb_lock_info(session_bound, ptr_max_size), 1..max_size) + .prop_map(move |xs| { + xs.into_iter() + .map(|mut x| { + x.handle_id = handle_id.clone(); + x + }) + .collect::>() + }) + }) + .prop_map(LockInfos) + .boxed() + } prop_compose! { - fn arb_request(session_bound: u64, ptr_max_size: usize)(li in arb_lock_info(session_bound, ptr_max_size)) -> (Request, CancelGuard) { + fn arb_request(session_bound: u64, ptr_max_size: usize)(lis in arb_lock_infos(session_bound, ptr_max_size, 10)) -> (Request, CancelGuard) { let (cancel_send, cancel_recv) = oneshot::channel(); let (guard_send, guard_recv) = oneshot::channel(); let r = Request { - lock_info: li.clone(), + lock_info: lis.clone(), cancel: Some(cancel_recv), completion: guard_send, }; let c = CancelGuard { - lock_info: Some(li), + lock_info: Some(lis), channel: Some(cancel_send), recv: guard_recv, }; @@ -152,7 +186,12 @@ mod tests { let mut queue = VecDeque::default(); for i in 0..n { let mut req = arb_request(1, 5).new_tree(&mut runner).unwrap().current(); - req.0.lock_info.handle_id.id = i; + match req.0.lock_info { + LockInfos(ref mut li) => { + li[0].handle_id.id = i; + } + _ => unreachable!(), + } let dep = if i == n - 1 { 0 } else { i + 1 }; queue.push_back(( req.0, @@ -165,7 +204,9 @@ mod tests { c.push_back(req.1); } for i in &queue { - println!("{} => {:?}", i.0.lock_info.handle_id.id, i.1) + for info in i.0.lock_info.as_vec() { + println!("{} => {:?}", info.handle_id.id, i.1) + } } let set = deadlock_scan(&queue); println!("{:?}", set); @@ -190,7 +231,7 @@ mod tests { let h = arb_handle_id(5).new_tree(&mut runner).unwrap().current(); let i = (0..queue.len()).new_tree(&mut runner).unwrap().current(); if let Some((r, s)) = queue.get_mut(i) { - if r.lock_info.handle_id != h { + if r.lock_info.as_vec().iter().all(|x| x.handle_id != h) { s.insert(h); } else { continue; @@ -199,24 +240,34 @@ mod tests { } else { // add new node let (r, c) = arb_request(5, 5).new_tree(&mut runner).unwrap().current(); + let request_infos = r.lock_info.as_vec(); // but only if the session hasn't yet been used - if queue - .iter() - .all(|(qr, _)| qr.lock_info.handle_id.id != r.lock_info.handle_id.id) - { + if queue.iter().all(|(qr, _)| { + for qr_info in qr.lock_info.as_vec() { + for request_info in request_infos.iter() { + if qr_info.handle_id == request_info.handle_id { + return false; + } + } + } + true + }) { queue.push_back((r, ordset![])); cancels.push_back(c); } } let cycle = deadlock_scan(&queue) .into_iter() - .map(|r| &r.lock_info.handle_id) + .flat_map(|x| x.lock_info.as_vec()) + .map(|r| &r.handle_id) .collect::>(); if !cycle.is_empty() { println!("Cycle: {:?}", cycle); for (r, s) in &queue { - if cycle.contains(&r.lock_info.handle_id) { - assert!(s.iter().any(|h| cycle.contains(h))) + for info in r.lock_info.as_vec() { + if cycle.contains(&info.handle_id) { + assert!(s.iter().any(|h| cycle.contains(h))) + } } } break; @@ -272,7 +323,7 @@ mod tests { proptest! { #[test] - fn trie_lock_inverse_identity(lock_order in proptest::collection::vec(arb_lock_info(1, 5), 1..30)) { + fn trie_lock_inverse_identity(lock_order in proptest::collection::vec(arb_lock_infos(1, 5, 10), 1..30)) { use crate::locker::trie::LockTrie; use rand::seq::SliceRandom; let mut trie = LockTrie::default(); @@ -280,10 +331,12 @@ mod tests { trie.try_lock(i).expect(&format!("try_lock failed: {}", i)); } let mut release_order = lock_order.clone(); - let slice: &mut [LockInfo] = &mut release_order[..]; + let slice: &mut [LockInfos] = &mut release_order[..]; slice.shuffle(&mut rand::thread_rng()); - for i in &release_order { - trie.unlock(i); + for is in &release_order { + for i in is.as_vec() { + trie.unlock(i); + } } prop_assert_eq!(trie, LockTrie::default()) } @@ -291,7 +344,7 @@ mod tests { proptest! { #[test] - fn enforcer_lock_inverse_identity(lock_order in proptest::collection::vec(arb_lock_info(1,3), 1..30)) { + fn enforcer_lock_inverse_identity(lock_order in proptest::collection::vec(arb_lock_infos(1,3,10), 1..30)) { use crate::locker::order_enforcer::LockOrderEnforcer; use rand::seq::SliceRandom; let mut enforcer = LockOrderEnforcer::new(); @@ -299,11 +352,13 @@ mod tests { enforcer.try_insert(i); } let mut release_order = lock_order.clone(); - let slice: &mut [LockInfo] = &mut release_order[..]; + let slice: &mut [LockInfos] = &mut release_order[..]; slice.shuffle(&mut rand::thread_rng()); prop_assert!(enforcer != LockOrderEnforcer::new()); - for i in &release_order { - enforcer.remove(i); + for is in &release_order { + for i in is.as_vec() { + enforcer.remove(i); + } } prop_assert_eq!(enforcer, LockOrderEnforcer::new()); } @@ -318,7 +373,7 @@ mod tests { let li0 = LockInfo { handle_id: s0, ty: LockType::Exist, - ptr: ptr0.clone() + ptr: ptr0.clone().into() }; println!("{}", ptr0); ptr0.append(&ptr1); @@ -326,11 +381,11 @@ mod tests { let li1 = LockInfo { handle_id: s1, ty: LockType::Write, - ptr: ptr0.clone() + ptr: ptr0.clone().into() }; - trie.try_lock(&li0).unwrap(); + trie.try_lock(&LockInfos(vec![li0])).unwrap(); println!("{:?}", trie); - trie.try_lock(&li1).expect("E locks don't prevent child locks"); + trie.try_lock(&LockInfos(vec![li1])).expect("E locks don't prevent child locks"); } } diff --git a/patch-db/src/locker/trie.rs b/patch-db/src/locker/trie.rs index b56dd21..0f27c41 100644 --- a/patch-db/src/locker/trie.rs +++ b/patch-db/src/locker/trie.rs @@ -3,10 +3,10 @@ use std::collections::BTreeMap; use imbl::{ordset, OrdSet}; use json_ptr::{JsonPointer, SegList}; -use super::natural::Natural; use super::LockInfo; -use crate::handle::HandleId; -use crate::LockType; +use super::{natural::Natural, LockInfos}; +use crate::{handle::HandleId, model_paths::JsonGlob}; +use crate::{model_paths::JsonGlobSegment, LockType}; #[derive(Debug, Clone, PartialEq, Eq)] enum LockState { @@ -51,10 +51,7 @@ impl LockState { } } fn write_free(&self) -> bool { - match self { - LockState::Exclusive { .. } => false, - _ => true, - } + !matches!(self, LockState::Exclusive { .. }) } fn read_free(&self) -> bool { match self { @@ -65,7 +62,7 @@ impl LockState { _ => true, } } - fn sessions<'a>(&'a self) -> OrdSet<&'a HandleId> { + fn sessions(&self) -> OrdSet<&'_ HandleId> { match self { LockState::Free => OrdSet::new(), LockState::Shared { @@ -76,7 +73,7 @@ impl LockState { } } #[allow(dead_code)] - fn exist_sessions<'a>(&'a self) -> OrdSet<&'a HandleId> { + fn exist_sessions(&self) -> OrdSet<&'_ HandleId> { match self { LockState::Free => OrdSet::new(), LockState::Shared { e_lessees, .. } => e_lessees.keys().collect(), @@ -93,7 +90,7 @@ impl LockState { } } } - fn read_sessions<'a>(&'a self) -> OrdSet<&'a HandleId> { + fn read_sessions(&self) -> OrdSet<&'_ HandleId> { match self { LockState::Free => OrdSet::new(), LockState::Shared { r_lessees, .. } => r_lessees.keys().collect(), @@ -110,7 +107,7 @@ impl LockState { } } } - fn write_session<'a>(&'a self) -> Option<&'a HandleId> { + fn write_session(&self) -> Option<&'_ HandleId> { match self { LockState::Exclusive { w_lessee, .. } => Some(w_lessee), _ => None, @@ -373,12 +370,9 @@ impl LockTrie { } #[allow(dead_code)] fn subtree_is_exclusive_free_for(&self, session: &HandleId) -> bool { - self.all(|s| match s.clone().erase(session) { - LockState::Exclusive { .. } => false, - _ => true, - }) + self.all(|s| !matches!(s.clone().erase(session), LockState::Exclusive { .. })) } - fn subtree_write_sessions<'a>(&'a self) -> OrdSet<&'a HandleId> { + fn subtree_write_sessions(&self) -> OrdSet<&'_ HandleId> { match &self.state { LockState::Exclusive { w_lessee, .. } => ordset![w_lessee], _ => self @@ -388,7 +382,7 @@ impl LockTrie { .fold(OrdSet::new(), OrdSet::union), } } - fn subtree_sessions<'a>(&'a self) -> OrdSet<&'a HandleId> { + fn subtree_sessions(&self) -> OrdSet<&'_ HandleId> { let children = self .children .values() @@ -396,19 +390,25 @@ impl LockTrie { .fold(OrdSet::new(), OrdSet::union); self.state.sessions().union(children) } - pub fn subtree_lock_info<'a>(&'a self) -> OrdSet { + pub fn subtree_lock_info(&self) -> OrdSet { let mut acc = self .children .iter() .map(|(s, t)| { t.subtree_lock_info() .into_iter() - .map(|mut i| LockInfo { + .map(|i| LockInfo { ty: i.ty, handle_id: i.handle_id, ptr: { - i.ptr.push_start(s); - i.ptr + i.ptr.append(s.parse().unwrap_or_else(|_| { + #[cfg(feature = "tracing")] + tracing::error!( + "Should never not be able to parse a string as a path" + ); + + Default::default() + })) }, }) .collect() @@ -416,7 +416,7 @@ impl LockTrie { .fold(ordset![], OrdSet::union); let self_writes = self.state.write_session().map(|session| LockInfo { handle_id: session.clone(), - ptr: JsonPointer::default(), + ptr: Default::default(), ty: LockType::Write, }); let self_reads = self @@ -425,7 +425,7 @@ impl LockTrie { .into_iter() .map(|session| LockInfo { handle_id: session.clone(), - ptr: JsonPointer::default(), + ptr: Default::default(), ty: LockType::Read, }); let self_exists = self @@ -434,13 +434,13 @@ impl LockTrie { .into_iter() .map(|session| LockInfo { handle_id: session.clone(), - ptr: JsonPointer::default(), + ptr: Default::default(), ty: LockType::Exist, }); acc.extend(self_writes.into_iter().chain(self_reads).chain(self_exists)); acc } - fn ancestors_and_trie<'a, S: AsRef, V: SegList>( + fn ancestors_and_trie_json_path<'a, S: AsRef, V: SegList>( &'a self, ptr: &JsonPointer, ) -> (Vec<&'a LockState>, Option<&'a LockTrie>) { @@ -449,107 +449,196 @@ impl LockTrie { Some((first, rest)) => match self.children.get(first) { None => (vec![&self.state], None), Some(t) => { - let (mut v, t) = t.ancestors_and_trie(&rest); + let (mut v, t) = t.ancestors_and_trie_json_path(&rest); v.push(&self.state); (v, t) } }, } } + fn ancestors_and_trie_model_paths<'a>( + &'a self, + path: &[JsonGlobSegment], + ) -> Vec<(Vec<&'a LockState>, Option<&'a LockTrie>)> { + let head = path.get(0); + match head { + None => vec![(Vec::new(), Some(self))], + Some(JsonGlobSegment::Star) => self + .children + .values() + .into_iter() + .flat_map(|lock_trie| lock_trie.ancestors_and_trie_model_paths(&path[1..])) + .collect(), + Some(JsonGlobSegment::Path(x)) => match self.children.get(x) { + None => vec![(vec![&self.state], None)], + Some(t) => t + .ancestors_and_trie_model_paths(&path[1..]) + .into_iter() + .map(|(mut v, t)| { + v.push(&self.state); + (v, t) + }) + .collect(), + }, + } + } + fn ancestors_and_trie<'a>( + &'a self, + ptr: &JsonGlob, + ) -> Vec<(Vec<&'a LockState>, Option<&'a LockTrie>)> { + match ptr { + JsonGlob::Path(x) => vec![self.ancestors_and_trie_json_path(x)], + JsonGlob::PathWithStar(path) => self.ancestors_and_trie_model_paths(path.segments()), + } + } // no writes in ancestor set, no writes at node #[allow(dead_code)] - fn can_acquire_exist(&self, ptr: &JsonPointer, session: &HandleId) -> bool { - let (v, t) = self.ancestors_and_trie(ptr); - let ancestor_write_free = v + fn can_acquire_exist(&self, ptr: &JsonGlob, session: &HandleId) -> bool { + let (vectors, tries): (Vec<_>, Vec<_>) = self.ancestors_and_trie(ptr).into_iter().unzip(); + let ancestor_write_free = vectors.into_iter().all(|v| { + v.into_iter() + .cloned() + .map(|s| s.erase(session)) + .all(|s| s.write_free()) + }); + let checking_end_tries_are_write_free = tries .into_iter() - .cloned() - .map(|s| s.erase(session)) - .all(|s| s.write_free()); - ancestor_write_free && t.map_or(true, |t| t.state.clone().erase(session).write_free()) + .all(|t| t.map_or(true, |t| t.state.clone().erase(session).write_free())); + ancestor_write_free && checking_end_tries_are_write_free } // no writes in ancestor set, no writes in subtree #[allow(dead_code)] - fn can_acquire_read(&self, ptr: &JsonPointer, session: &HandleId) -> bool { - let (v, t) = self.ancestors_and_trie(ptr); - let ancestor_write_free = v + fn can_acquire_read(&self, ptr: &JsonGlob, session: &HandleId) -> bool { + let (vectors, tries): (Vec<_>, Vec<_>) = self.ancestors_and_trie(ptr).into_iter().unzip(); + let ancestor_write_free = vectors.into_iter().all(|v| { + v.into_iter() + .cloned() + .map(|s| s.erase(session)) + .all(|s| s.write_free()) + }); + let end_nodes_are_correct = tries .into_iter() - .cloned() - .map(|s| s.erase(session)) - .all(|s| s.write_free()); - ancestor_write_free && t.map_or(true, |t| t.subtree_is_exclusive_free_for(session)) + .all(|t| t.map_or(true, |t| t.subtree_is_exclusive_free_for(session))); + ancestor_write_free && end_nodes_are_correct } // no reads or writes in ancestor set, no locks in subtree #[allow(dead_code)] - fn can_acquire_write(&self, ptr: &JsonPointer, session: &HandleId) -> bool { - let (v, t) = self.ancestors_and_trie(ptr); - let ancestor_rw_free = v + fn can_acquire_write(&self, ptr: &JsonGlob, session: &HandleId) -> bool { + let (vectors, tries): (Vec<_>, Vec<_>) = self.ancestors_and_trie(ptr).into_iter().unzip(); + let ancestor_rw_free = vectors.into_iter().all(|v| { + v.into_iter() + .cloned() + .map(|s| s.erase(session)) + .all(|s| s.write_free() && s.read_free()) + }); + + let end_nodes_are_correct = tries .into_iter() - .cloned() - .map(|s| s.erase(session)) - .all(|s| s.write_free() && s.read_free()); - ancestor_rw_free && t.map_or(true, |t| t.subtree_is_lock_free_for(session)) + .all(|t| t.map_or(true, |t| t.subtree_is_lock_free_for(session))); + ancestor_rw_free && end_nodes_are_correct } // ancestors with writes and writes on the node fn session_blocking_exist<'a>( &'a self, - ptr: &JsonPointer, + ptr: &JsonGlob, session: &HandleId, ) -> Option<&'a HandleId> { - let (v, t) = self.ancestors_and_trie(ptr); - // there can only be one write session per traversal - let ancestor_write = v.into_iter().find_map(|s| s.write_session()); - let node_write = t.and_then(|t| t.state.write_session()); - ancestor_write - .or(node_write) - .and_then(|s| if s == session { None } else { Some(s) }) + let vectors_and_tries = self.ancestors_and_trie(ptr); + vectors_and_tries.into_iter().find_map(|(v, t)| { + // there can only be one write session per traversal + let ancestor_write = v.into_iter().find_map(|s| s.write_session()); + let node_write = t.and_then(|t| t.state.write_session()); + ancestor_write + .or(node_write) + .and_then(|s| if s == session { None } else { Some(s) }) + }) } // ancestors with writes, subtrees with writes fn sessions_blocking_read<'a>( &'a self, - ptr: &JsonPointer, + ptr: &JsonGlob, session: &HandleId, ) -> OrdSet<&'a HandleId> { - let (v, t) = self.ancestors_and_trie(ptr); - let ancestor_writes = v + let vectors_and_tries = self.ancestors_and_trie(ptr); + vectors_and_tries .into_iter() - .map(|s| s.write_session().into_iter().collect::>()) - .fold(OrdSet::new(), OrdSet::union); - let relevant_write_sessions = match t { - None => ancestor_writes, - Some(t) => ancestor_writes.union(t.subtree_write_sessions()), - }; - relevant_write_sessions.without(session) + .flat_map(|(v, t)| { + let ancestor_writes = v + .into_iter() + .map(|s| s.write_session().into_iter().collect::>()) + .fold(OrdSet::new(), OrdSet::union); + let relevant_write_sessions = match t { + None => ancestor_writes, + Some(t) => ancestor_writes.union(t.subtree_write_sessions()), + }; + relevant_write_sessions.without(session) + }) + .collect() } // ancestors with reads or writes, subtrees with anything fn sessions_blocking_write<'a>( &'a self, - ptr: &JsonPointer, + ptr: &JsonGlob, session: &HandleId, ) -> OrdSet<&'a HandleId> { - let (v, t) = self.ancestors_and_trie(ptr); - let ancestors = v + let vectors_and_tries = self.ancestors_and_trie(ptr); + vectors_and_tries .into_iter() - .map(|s| { - s.read_sessions() - .union(s.write_session().into_iter().collect()) + .flat_map(|(v, t)| { + let ancestors = v + .into_iter() + .map(|s| { + s.read_sessions() + .union(s.write_session().into_iter().collect()) + }) + .fold(OrdSet::new(), OrdSet::union); + let subtree = t.map_or(OrdSet::new(), |t| t.subtree_sessions()); + ancestors.union(subtree).without(session) }) - .fold(OrdSet::new(), OrdSet::union); - let subtree = t.map_or(OrdSet::new(), |t| t.subtree_sessions()); - ancestors.union(subtree).without(session) + .collect() } - fn child_mut, V: SegList>(&mut self, ptr: &JsonPointer) -> &mut Self { + fn child_mut_pointer, V: SegList>( + &mut self, + ptr: &JsonPointer, + ) -> &mut Self { match ptr.uncons() { None => self, Some((first, rest)) => { if !self.children.contains_key(first) { self.children.insert(first.to_owned(), LockTrie::default()); } - self.children.get_mut(first).unwrap().child_mut(&rest) + self.children + .get_mut(first) + .unwrap() + .child_mut_pointer(&rest) } } } + fn child_mut(&mut self, ptr: &JsonGlob) -> &mut Self { + match ptr { + JsonGlob::Path(x) => self.child_mut_pointer(x), + JsonGlob::PathWithStar(path) => self.child_mut_paths(path.segments()), + } + } + + fn child_mut_paths(&mut self, path: &[JsonGlobSegment]) -> &mut LockTrie { + let mut current = self; + let paths_iter = path.iter(); + for head in paths_iter { + let key = match head { + JsonGlobSegment::Path(path) => path.clone(), + JsonGlobSegment::Star => "*".to_string(), + }; + if !current.children.contains_key(&key) { + current.children.insert(key.to_owned(), LockTrie::default()); + } + current = current.children.get_mut(&key).unwrap(); + } + current + } + fn sessions_blocking_lock<'a>(&'a self, lock_info: &LockInfo) -> OrdSet<&'a HandleId> { match &lock_info.ty { LockType::Exist => self @@ -561,17 +650,23 @@ impl LockTrie { } } - pub fn try_lock<'a>(&'a mut self, lock_info: &LockInfo) -> Result<(), OrdSet> { - let blocking_sessions = self.sessions_blocking_lock(lock_info); + pub fn try_lock(&mut self, lock_infos: &LockInfos) -> Result<(), OrdSet> { + let lock_info_vec = lock_infos.as_vec(); + let blocking_sessions: OrdSet<_> = lock_info_vec + .iter() + .flat_map(|lock_info| self.sessions_blocking_lock(lock_info)) + .collect(); if !blocking_sessions.is_empty() { Err(blocking_sessions.into_iter().cloned().collect()) } else { drop(blocking_sessions); - let success = self - .child_mut(&lock_info.ptr) - .state - .try_lock(lock_info.handle_id.clone(), &lock_info.ty); - assert!(success); + for lock_info in lock_info_vec { + let success = self + .child_mut(&lock_info.ptr) + .state + .try_lock(lock_info.handle_id.clone(), &lock_info.ty); + assert!(success); + } Ok(()) } } diff --git a/patch-db/src/model.rs b/patch-db/src/model.rs index ca78a35..08b1b84 100644 --- a/patch-db/src/model.rs +++ b/patch-db/src/model.rs @@ -9,7 +9,11 @@ use json_ptr::JsonPointer; use serde::{Deserialize, Serialize}; use serde_json::Value; -use crate::locker::LockType; +use crate::{ + bulk_locks::{self, unsaturated_args::AsUnsaturatedArgs, LockTarget}, + locker::LockType, + model_paths::JsonGlob, +}; use crate::{DbHandle, DiffPatch, Error, Revision}; #[derive(Debug)] @@ -55,46 +59,91 @@ impl Deserialize<'de>> DerefMut for ModelDataMut { &mut self.current } } - #[derive(Debug)] pub struct Model Deserialize<'de>> { - ptr: JsonPointer, + pub(crate) path: JsonGlob, phantom: PhantomData, } + +lazy_static::lazy_static!( + static ref EMPTY_JSON: JsonPointer = JsonPointer::default(); +); + impl Model where T: Serialize + for<'de> Deserialize<'de>, { pub async fn lock(&self, db: &mut Db, lock_type: LockType) -> Result<(), Error> { - Ok(db.lock(self.ptr.clone(), lock_type).await?) + Ok(db.lock(self.json_ptr().clone().into(), lock_type).await?) } pub async fn get(&self, db: &mut Db, lock: bool) -> Result, Error> { if lock { self.lock(db, LockType::Read).await?; } - Ok(ModelData(db.get(&self.ptr).await?)) + Ok(ModelData(db.get(self.json_ptr()).await?)) } pub async fn get_mut(&self, db: &mut Db) -> Result, Error> { self.lock(db, LockType::Write).await?; - let original = db.get_value(&self.ptr, None).await?; + let original = db.get_value(self.json_ptr(), None).await?; let current = serde_json::from_value(original.clone())?; Ok(ModelDataMut { original, current, - ptr: self.ptr.clone(), + ptr: self.json_ptr().clone(), }) } - + /// Used for times of Serialization, or when going into the db + fn json_ptr(&self) -> &JsonPointer { + match self.path { + JsonGlob::Path(ref ptr) => ptr, + JsonGlob::PathWithStar { .. } => { + #[cfg(feature = "tracing")] + tracing::error!("Should be unreachable, since the type of () means that the paths is always Paths"); + &*EMPTY_JSON + } + } + } +} +impl Model +where + T: Serialize + for<'de> Deserialize<'de>, +{ pub fn child Deserialize<'de>>(self, index: &str) -> Model { - let mut ptr = self.ptr; - ptr.push_end(index); + let path = self.path.append(index.parse().unwrap_or_else(|_e| { + #[cfg(feature = "trace")] + tracing::error!("Shouldn't ever not be able to parse a path"); + Default::default() + })); Model { - ptr, + path, phantom: PhantomData, } } + + /// One use is gettign the modelPaths for the bulk locks + pub fn model_paths(&self) -> &JsonGlob { + &self.path + } +} + +impl Model +where + T: Serialize + for<'de> Deserialize<'de>, +{ + /// Used to create a lock for the db + pub fn make_locker(&self, lock_type: LockType) -> LockTarget + where + JsonGlob: AsUnsaturatedArgs, + { + bulk_locks::LockTarget { + lock_type, + db_type: self.phantom, + _star_binds: self.path.as_unsaturated_args(), + glob: self.path.clone(), + } + } } impl Model where @@ -106,7 +155,7 @@ where value: &T, ) -> Result>, Error> { self.lock(db, LockType::Write).await?; - db.put(&self.ptr, value).await + db.put(self.json_ptr(), value).await } } impl From for Model @@ -115,7 +164,18 @@ where { fn from(ptr: JsonPointer) -> Self { Self { - ptr, + path: JsonGlob::Path(ptr), + phantom: PhantomData, + } + } +} +impl From for Model +where + T: Serialize + for<'de> Deserialize<'de>, +{ + fn from(ptr: JsonGlob) -> Self { + Self { + path: ptr, phantom: PhantomData, } } @@ -125,7 +185,7 @@ where T: Serialize + for<'de> Deserialize<'de>, { fn as_ref(&self) -> &JsonPointer { - &self.ptr + self.json_ptr() } } impl From> for JsonPointer @@ -133,7 +193,15 @@ where T: Serialize + for<'de> Deserialize<'de>, { fn from(model: Model) -> Self { - model.ptr + model.json_ptr().clone() + } +} +impl From> for JsonGlob +where + T: Serialize + for<'de> Deserialize<'de>, +{ + fn from(model: Model) -> Self { + model.path } } impl std::clone::Clone for Model @@ -142,7 +210,7 @@ where { fn clone(&self) -> Self { Model { - ptr: self.ptr.clone(), + path: self.path.clone(), phantom: PhantomData, } } @@ -153,12 +221,24 @@ pub trait HasModel: Serialize + for<'de> Deserialize<'de> { } pub trait ModelFor Deserialize<'de>>: - From + AsRef + Into + From> + Clone + From + + From + + AsRef + + Into + + From> + + Clone + + Into { } impl< T: Serialize + for<'de> Deserialize<'de>, - U: From + AsRef + Into + From> + Clone, + U: From + + From + + AsRef + + Into + + From> + + Clone + + Into, > ModelFor for U { } @@ -167,6 +247,7 @@ macro_rules! impl_simple_has_model { ($($ty:ty),*) => { $( impl HasModel for $ty { + type Model = Model<$ty>; } )* @@ -208,11 +289,21 @@ impl Deserialize<'de>> From for BoxModel(T::Model::from(ptr)) } } +impl Deserialize<'de>> From for BoxModel { + fn from(ptr: JsonGlob) -> Self { + BoxModel(T::Model::from(ptr)) + } +} impl Deserialize<'de>> From> for JsonPointer { fn from(model: BoxModel) -> Self { model.0.into() } } +impl Deserialize<'de>> From> for JsonGlob { + fn from(model: BoxModel) -> Self { + model.0.into() + } +} impl std::clone::Clone for BoxModel where T: HasModel + Serialize + for<'de> Deserialize<'de>, @@ -229,7 +320,7 @@ impl Deserialize<'de>> HasModel for Box { pub struct OptionModel Deserialize<'de>>(T::Model); impl Deserialize<'de>> OptionModel { pub async fn lock(&self, db: &mut Db, lock_type: LockType) -> Result<(), Error> { - Ok(db.lock(self.0.as_ref().clone(), lock_type).await?) + Ok(db.lock(self.0.as_ref().clone().into(), lock_type).await?) } pub async fn get( @@ -259,9 +350,10 @@ impl Deserialize<'de>> OptionModel { pub async fn exists(&self, db: &mut Db, lock: bool) -> Result { if lock { - db.lock(self.0.as_ref().clone(), LockType::Exist).await?; + db.lock(self.0.as_ref().clone().into(), LockType::Exist) + .await?; } - Ok(db.exists(&self.as_ref(), None).await?) + Ok(db.exists(self.as_ref(), None).await?) } pub fn map< @@ -285,10 +377,30 @@ impl Deserialize<'de>> OptionModel { } pub async fn delete(&self, db: &mut Db) -> Result>, Error> { - db.lock(self.as_ref().clone(), LockType::Write).await?; + db.lock(self.as_ref().clone().into(), LockType::Write) + .await?; db.put(self.as_ref(), &Value::Null).await } } + +impl OptionModel +where + T: HasModel, +{ + /// Used to create a lock for the db + pub fn make_locker(self, lock_type: LockType) -> LockTarget + where + JsonGlob: AsUnsaturatedArgs, + { + let paths: JsonGlob = self.into(); + bulk_locks::LockTarget { + _star_binds: paths.as_unsaturated_args(), + glob: paths, + lock_type, + db_type: PhantomData, + } + } +} impl OptionModel where T: HasModel + Serialize + for<'de> Deserialize<'de>, @@ -319,7 +431,8 @@ where db: &mut Db, value: &T, ) -> Result>, Error> { - db.lock(self.as_ref().clone(), LockType::Write).await?; + db.lock(self.as_ref().clone().into(), LockType::Write) + .await?; db.put(self.as_ref(), value).await } } @@ -327,7 +440,7 @@ impl Deserialize<'de>> From>> for OptionModel { fn from(model: Model>) -> Self { - OptionModel(T::Model::from(JsonPointer::from(model))) + OptionModel(T::Model::from(JsonGlob::from(model))) } } impl Deserialize<'de>> From for OptionModel { @@ -335,11 +448,21 @@ impl Deserialize<'de>> From for OptionModel(T::Model::from(ptr)) } } +impl Deserialize<'de>> From for OptionModel { + fn from(ptr: JsonGlob) -> Self { + OptionModel(T::Model::from(ptr)) + } +} impl Deserialize<'de>> From> for JsonPointer { fn from(model: OptionModel) -> Self { model.0.into() } } +impl Deserialize<'de>> From> for JsonGlob { + fn from(model: OptionModel) -> Self { + model.0.into() + } +} impl Deserialize<'de>> AsRef for OptionModel { fn as_ref(&self) -> &JsonPointer { self.0.as_ref() @@ -382,7 +505,7 @@ impl Deserialize<'de>> VecModel { } impl Deserialize<'de>> From>> for VecModel { fn from(model: Model>) -> Self { - VecModel(From::from(JsonPointer::from(model))) + VecModel(From::from(JsonGlob::from(model))) } } impl Deserialize<'de>> From for VecModel { @@ -390,11 +513,21 @@ impl Deserialize<'de>> From for VecModel VecModel(From::from(ptr)) } } +impl Deserialize<'de>> From for VecModel { + fn from(ptr: JsonGlob) -> Self { + VecModel(From::from(ptr)) + } +} impl Deserialize<'de>> From> for JsonPointer { fn from(model: VecModel) -> Self { model.0.into() } } +impl Deserialize<'de>> From> for JsonGlob { + fn from(model: VecModel) -> Self { + model.0.into() + } +} impl AsRef for VecModel where T: Serialize + for<'de> Deserialize<'de>, @@ -490,16 +623,18 @@ where lock: bool, ) -> Result, Error> { if lock { - db.lock(self.as_ref().clone(), LockType::Exist).await?; + db.lock(self.json_ptr().clone().into(), LockType::Exist) + .await?; } - let set = db.keys(self.as_ref(), None).await?; + let set = db.keys(self.json_ptr(), None).await?; Ok(set .into_iter() .map(|s| serde_json::from_value(Value::String(s))) .collect::>()?) } pub async fn remove(&self, db: &mut Db, key: &T::Key) -> Result<(), Error> { - db.lock(self.as_ref().clone(), LockType::Write).await?; + db.lock(self.as_ref().clone().into(), LockType::Write) + .await?; if db.exists(self.clone().idx(key).as_ref(), None).await? { db.apply( DiffPatch(Patch(vec![PatchOperation::Remove(RemoveOperation { @@ -521,6 +656,22 @@ where self.0.child(idx.as_ref()).into() } } + +impl MapModel +where + T: Serialize + for<'de> Deserialize<'de> + Map, + T::Value: Serialize + for<'de> Deserialize<'de> + HasModel, +{ + /// Used when mapping across all possible paths of a map or such, to later be filled + pub fn star(self) -> <::Value as HasModel>::Model { + let path = self.0.path.append(JsonGlob::star()); + Model { + path, + phantom: PhantomData, + } + .into() + } +} impl From> for MapModel where T: Serialize + for<'de> Deserialize<'de> + Map, @@ -539,6 +690,15 @@ where MapModel(From::from(ptr)) } } +impl From for MapModel +where + T: Serialize + for<'de> Deserialize<'de> + Map, + T::Value: Serialize + for<'de> Deserialize<'de>, +{ + fn from(ptr: JsonGlob) -> Self { + MapModel(From::from(ptr)) + } +} impl From> for JsonPointer where T: Serialize + for<'de> Deserialize<'de> + Map, @@ -548,6 +708,15 @@ where model.0.into() } } +impl From> for JsonGlob +where + T: Serialize + for<'de> Deserialize<'de> + Map, + T::Value: Serialize + for<'de> Deserialize<'de>, +{ + fn from(model: MapModel) -> Self { + model.0.into() + } +} impl AsRef for MapModel where T: Serialize + for<'de> Deserialize<'de> + Map, diff --git a/patch-db/src/model_paths.rs b/patch-db/src/model_paths.rs new file mode 100644 index 0000000..3b05a11 --- /dev/null +++ b/patch-db/src/model_paths.rs @@ -0,0 +1,419 @@ +use std::str::FromStr; + +use json_ptr::JsonPointer; + +/// Used in the locking of a model where we have an all, a predicate to filter children. +/// This is split because we know the path or we have predicate filters +/// We split once we got the all, so we could go into the models and lock all of services.name for example +/// without locking all of them. +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] +pub enum JsonGlobSegment { + /// Used to be just be a regular json path + Path(String), + /// Indicating that we are going to be using some part of all of this Vec, Map, etc. + Star, +} +impl std::fmt::Display for JsonGlobSegment { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + JsonGlobSegment::Path(x) => { + write!(f, "{}", x)?; + } + JsonGlobSegment::Star => { + write!(f, "*")?; + } + } + Ok(()) + } +} + +/// Use in the model to point from root down a specific path +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] +pub enum JsonGlob { + /// This was the default + Path(JsonPointer), + /// Once we add an All, our predicate, we don't know the possible paths could be in the maps so we are filling + /// in binds for the possible paths to take. + PathWithStar(PathWithStar), +} + +/// Path including the glob +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)] +pub struct PathWithStar { + segments: Vec, + count: usize, +} + +impl PathWithStar { + pub fn segments(&self) -> &[JsonGlobSegment] { + &self.segments + } + pub fn count(&self) -> usize { + self.count + } +} + +impl Default for JsonGlob { + fn default() -> Self { + Self::Path(Default::default()) + } +} + +impl From for JsonGlob { + fn from(pointer: JsonPointer) -> Self { + Self::Path(pointer) + } +} + +impl std::fmt::Display for JsonGlob { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + JsonGlob::Path(x) => { + write!(f, "{}", x)?; + } + JsonGlob::PathWithStar(PathWithStar { + segments: path, + count: _, + }) => { + for path in path.iter() { + write!(f, "/")?; + write!(f, "{}", path)?; + } + } + } + Ok(()) + } +} + +impl FromStr for JsonGlob { + type Err = String; + + fn from_str(s: &str) -> Result { + let split = s.split('/').filter(|x| !x.is_empty()); + if !s.contains('*') { + return Ok(JsonGlob::Path(split.fold( + JsonPointer::default(), + |mut pointer, s| { + pointer.push_end(s); + pointer + }, + ))); + } + let segments: Vec = split + .map(|x| match x { + "*" => JsonGlobSegment::Star, + x => JsonGlobSegment::Path(x.to_string()), + }) + .collect(); + let segments = segments; + let count = segments + .iter() + .filter(|x| matches!(x, JsonGlobSegment::Star)) + .count(); + Ok(JsonGlob::PathWithStar(PathWithStar { segments, count })) + } +} + +impl JsonGlob { + pub fn append(self, path: JsonGlob) -> Self { + fn append_stars( + PathWithStar { + segments: mut left_segments, + count: left_count, + }: PathWithStar, + PathWithStar { + segments: mut right_segments, + count: right_count, + }: PathWithStar, + ) -> PathWithStar { + left_segments.append(&mut right_segments); + + PathWithStar { + segments: left_segments, + count: left_count + right_count, + } + } + + fn point_as_path_with_star(pointer: JsonPointer) -> PathWithStar { + PathWithStar { + segments: pointer.into_iter().map(JsonGlobSegment::Path).collect(), + count: 0, + } + } + + match (self, path) { + (JsonGlob::Path(mut paths), JsonGlob::Path(right_paths)) => { + paths.append(&right_paths); + JsonGlob::Path(paths) + } + (JsonGlob::Path(left), JsonGlob::PathWithStar(right)) => { + JsonGlob::PathWithStar(append_stars(point_as_path_with_star(left), right)) + } + (JsonGlob::PathWithStar(left), JsonGlob::Path(right)) => { + JsonGlob::PathWithStar(append_stars(left, point_as_path_with_star(right))) + } + (JsonGlob::PathWithStar(left), JsonGlob::PathWithStar(right)) => { + JsonGlob::PathWithStar(append_stars(left, right)) + } + } + } + + /// Used during the creation of star paths + pub fn star() -> Self { + JsonGlob::PathWithStar(PathWithStar { + segments: vec![JsonGlobSegment::Star], + count: 1, + }) + } + + /// There are points that we use the JsonPointer starts_with, and we need to be able to + /// utilize that and to be able to deal with the star paths + pub fn starts_with(&self, other: &JsonGlob) -> bool { + fn starts_with_<'a>(left: &Vec, right: &Vec) -> bool { + let mut left_paths = left.iter(); + let mut right_paths = right.iter(); + loop { + match (left_paths.next(), right_paths.next()) { + (Some(JsonGlobSegment::Path(x)), Some(JsonGlobSegment::Path(y))) => { + if x != y { + return false; + } + } + (Some(JsonGlobSegment::Star), Some(JsonGlobSegment::Star)) => {} + (Some(JsonGlobSegment::Star), Some(JsonGlobSegment::Path(_))) => {} + (Some(JsonGlobSegment::Path(_)), Some(JsonGlobSegment::Star)) => {} + (None, None) => return true, + (None, _) => return false, + (_, None) => return true, + } + } + } + match (self, other) { + (JsonGlob::Path(x), JsonGlob::Path(y)) => x.starts_with(y), + ( + JsonGlob::Path(x), + JsonGlob::PathWithStar(PathWithStar { + segments: path, + count: _, + }), + ) => starts_with_( + &x.iter() + .map(|x| JsonGlobSegment::Path(x.to_string())) + .collect(), + path, + ), + ( + JsonGlob::PathWithStar(PathWithStar { + segments: path, + count: _, + }), + JsonGlob::Path(y), + ) => starts_with_( + path, + &y.iter() + .map(|x| JsonGlobSegment::Path(x.to_string())) + .collect(), + ), + ( + JsonGlob::PathWithStar(PathWithStar { + segments: path, + count: _, + }), + JsonGlob::PathWithStar(PathWithStar { + segments: path_other, + count: _, + }), + ) => starts_with_(path, path_other), + } + } + /// When we need to convert back into a usuable pointer string that is used for the paths of the + /// get and set of the db. + pub fn as_pointer(&self, binds: &[&str]) -> JsonPointer { + match self { + JsonGlob::Path(json_pointer) => json_pointer.clone(), + JsonGlob::PathWithStar(PathWithStar { + segments: path, + count: _, + }) => { + let mut json_pointer: JsonPointer = Default::default(); + let mut binds = binds.iter(); + for path in (*path).iter() { + match path { + JsonGlobSegment::Path(path) => json_pointer.push_end(&path), + JsonGlobSegment::Star => { + if let Some(path) = binds.next() { + json_pointer.push_end(path) + } + } + } + } + json_pointer + } + } + } + + pub fn star_count(&self) -> usize { + match self { + JsonGlob::Path(_) => 0, + JsonGlob::PathWithStar(PathWithStar { count, .. }) => *count, + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + use proptest::prelude::*; + + #[test] + fn model_paths_parse_simple() { + let path = "/a/b/c"; + let model_paths = JsonGlob::from_str(path).unwrap(); + assert_eq!( + model_paths.as_pointer(&[]), + JsonPointer::from_str(path).unwrap() + ); + } + #[test] + fn model_paths_parse_star() { + let path = "/a/b/c/*/e"; + let model_paths = JsonGlob::from_str(path).unwrap(); + assert_eq!( + model_paths.as_pointer(&["d"]), + JsonPointer::from_str("/a/b/c/d/e").unwrap() + ); + } + + #[test] + fn append() { + let path = "/a/b/"; + let model_paths = JsonGlob::from_str(path) + .unwrap() + .append("c".parse().unwrap()); + assert_eq!( + model_paths.as_pointer(&[]), + JsonPointer::from_str("/a/b/c").unwrap() + ); + } + #[test] + fn append_star() { + let path = "/a/b/"; + let model_paths = JsonGlob::from_str(path) + .unwrap() + .append("*".parse().unwrap()); + assert_eq!( + model_paths.as_pointer(&["c"]), + JsonPointer::from_str("/a/b/c").unwrap() + ); + } + #[test] + fn star_append() { + let path = "/a/*/"; + let model_paths = JsonGlob::from_str(path) + .unwrap() + .append("c".parse().unwrap()); + assert_eq!( + model_paths.as_pointer(&["b"]), + JsonPointer::from_str("/a/b/c").unwrap() + ); + } + #[test] + fn star_append_star() { + let path = "/a/*/"; + let model_paths = JsonGlob::from_str(path) + .unwrap() + .append("*".parse().unwrap()); + assert_eq!( + model_paths.as_pointer(&["b", "c"]), + JsonPointer::from_str("/a/b/c").unwrap() + ); + } + #[test] + fn starts_with_paths() { + let path: JsonGlob = "/a/b".parse().unwrap(); + let path_b: JsonGlob = "/a".parse().unwrap(); + let path_c: JsonGlob = "/a/b/c".parse().unwrap(); + assert!(path.starts_with(&path_b)); + assert!(!path.starts_with(&path_c)); + assert!(path_c.starts_with(&path)); + assert!(!path_b.starts_with(&path)); + } + #[test] + fn starts_with_star_left() { + let path: JsonGlob = "/a/*/c".parse().unwrap(); + let path_a: JsonGlob = "/a".parse().unwrap(); + let path_b: JsonGlob = "/b".parse().unwrap(); + let path_full_c: JsonGlob = "/a/b/c".parse().unwrap(); + let path_full_c_d: JsonGlob = "/a/b/c/d".parse().unwrap(); + let path_full_d_other: JsonGlob = "/a/b/d".parse().unwrap(); + assert!(path.starts_with(&path_a)); + assert!(path.starts_with(&path)); + assert!(!path.starts_with(&path_b)); + assert!(path.starts_with(&path_full_c)); + assert!(!path.starts_with(&path_full_c_d)); + assert!(!path.starts_with(&path_full_d_other)); + + // Others start with + assert!(!path_a.starts_with(&path)); + assert!(!path_b.starts_with(&path)); + assert!(path_full_c.starts_with(&path)); + assert!(path_full_c_d.starts_with(&path)); + assert!(!path_full_d_other.starts_with(&path)); + } + + /// A valid star path is something like `/a/*/c` + /// A path may start with a letter, then any letter/ dash/ number + /// A star path may only be a star + pub fn arb_path_str() -> impl Strategy { + // Funny enough we can't test the max size, running out of memory, funny that + proptest::collection::vec("([a-z][a-z\\-0-9]*|\\*)", 0..100).prop_map(|a_s| { + a_s.into_iter().fold(String::new(), |mut s, x| { + s.push('/'); + s.push_str(&x); + s + }) + }) + } + + mod star_counts { + use super::*; + #[test] + fn base_have_valid_star_count() { + let path = "/a/*/c"; + let glob = JsonGlob::from_str(&path).unwrap(); + assert_eq!( + glob.star_count(), + 1, + "Star count should be the total number of star paths for path {}", + path + ); + } + proptest! { + #[test] + fn all_valid_paths_have_valid_star_count(path in arb_path_str()) { + let glob = JsonGlob::from_str(&path).unwrap(); + prop_assert_eq!(glob.star_count(), path.matches('*').count(), "Star count should be the total number of star paths for path {}", path); + } + } + } + + proptest! { + #[test] + fn inductive_append_as_monoid(left in arb_path_str(), right in arb_path_str()) { + let left_glob = JsonGlob::from_str(&left).unwrap(); + let right_glob = JsonGlob::from_str(&right).unwrap(); + let expected_join = format!("{}{}", left, right); + let expected = JsonGlob::from_str(&expected_join).unwrap(); + let answer = left_glob.append(right_glob); + prop_assert_eq!(answer, expected, "Appending another path should be the same as joining them as a string first for path {}", expected_join); + } + + #[test] + fn all_globs_parse_display_isomorphism(path in arb_path_str()) { + let glob = JsonGlob::from_str(&path).unwrap(); + let other_glob = JsonGlob::from_str(&glob.to_string()).unwrap(); + prop_assert_eq!(other_glob, glob); + } + } +} diff --git a/patch-db/src/test.rs b/patch-db/src/test.rs index b39ef37..279249f 100644 --- a/patch-db/src/test.rs +++ b/patch-db/src/test.rs @@ -109,8 +109,8 @@ async fn locks_dropped_from_enforcer_on_tx_save() { let mut tx = handle.begin().await.unwrap(); let ptr_a: JsonPointer = "/a".parse().unwrap(); let ptr_b: JsonPointer = "/b".parse().unwrap(); - tx.lock(ptr_b, LockType::Write).await.unwrap(); + tx.lock(ptr_b.into(), LockType::Write).await.unwrap(); tx.save().await.unwrap(); - handle.lock(ptr_a, LockType::Write).await.unwrap(); + handle.lock(ptr_a.into(), LockType::Write).await.unwrap(); cleanup_db("test.db").await; } diff --git a/patch-db/src/transaction.rs b/patch-db/src/transaction.rs index 1fcec18..bf3dd9f 100644 --- a/patch-db/src/transaction.rs +++ b/patch-db/src/transaction.rs @@ -9,10 +9,13 @@ use tokio::sync::broadcast::error::TryRecvError; use tokio::sync::broadcast::Receiver; use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; -use crate::handle::HandleId; -use crate::locker::{Guard, LockType, Locker}; use crate::patch::{DiffPatch, Revision}; use crate::store::Store; +use crate::{ + bulk_locks::Verifier, + locker::{Guard, LockType, Locker}, +}; +use crate::{handle::HandleId, model_paths::JsonGlob}; use crate::{DbHandle, Error, PatchDbHandle}; pub struct Transaction { @@ -165,13 +168,14 @@ impl DbHandle for Transaction { self.updates.append(patch); Ok(None) } - async fn lock(&mut self, ptr: JsonPointer, lock_type: LockType) -> Result<(), Error> { - Ok(self.locks.push( + async fn lock(&mut self, ptr: JsonGlob, lock_type: LockType) -> Result<(), Error> { + self.locks.push( self.parent .locker() .lock(self.id.clone(), ptr, lock_type) .await?, - )) + ); + Ok(()) } async fn get< T: for<'de> Deserialize<'de>, @@ -202,4 +206,16 @@ impl DbHandle for Transaction { self.updates.append(patch); Ok(None) } + + async fn lock_all<'a>( + &'a mut self, + locks: impl IntoIterator + Send + Clone + 'a, + ) -> Result { + let verifier = Verifier { + target_locks: locks.clone().into_iter().collect(), + }; + self.locks + .push(self.parent.locker().lock_all(&self.id, locks).await?); + Ok(verifier) + } }