fix locking logic

This commit is contained in:
Aiden McClelland
2021-03-09 19:30:56 -07:00
parent 4dbff195d5
commit c2e50d0e88
7 changed files with 305 additions and 44 deletions

6
Cargo.toml Normal file
View File

@@ -0,0 +1,6 @@
[workspace]
members = [
"patch-db",
"patch-db-derive",
"patch-db-derive-internals",
]

View File

@@ -0,0 +1,15 @@
[package]
name = "patch-db-derive-internals"
version = "0.1.0"
authors = ["Aiden McClelland <me@drbonez.dev>"]
edition = "2018"
description = "internals for derive macros for defining typed patch dbs"
license = "MIT"
repository = "https://github.com/dr-bonez/patch-db"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
syn = { version = "1.0.5", features = ["full", "extra-traits"] }
quote = "1.0.1"
proc-macro2 = "1.0.1"

View File

@@ -0,0 +1,42 @@
use proc_macro2::TokenStream;
use quote::quote;
pub fn build_model(input: &syn::DeriveInput) -> TokenStream {
match &input.data {
syn::Data::Struct(struct_ast) => build_model_struct(input, struct_ast),
syn::Data::Enum(enum_ast) => build_model_enum(enum_ast),
syn::Data::Union(_) => panic!("Unions are not supported"),
}
}
fn build_model_struct(input: &syn::DeriveInput, ast: &syn::DataStruct) -> TokenStream {
let model_name = syn::Ident::new(
&format!("{}Model", input.ident),
proc_macro2::Span::call_site(),
);
let base_name = &input.ident;
let model_vis = &input.vis;
quote! {
#model_vis struct #model_name<Tx: patch_db::Checkpoint> {
data: Option<Box<#base_name>>,
ptr: json_ptr::JsonPointer,
tx: Tx,
}
impl<Tx: patch_db::Checkpoint> #model_name<Tx> {
pub fn get(&mut self, lock: patch_db::LockType) -> Result<&#base_name, patch_db::Error> {
if let Some(data) = self.data.as_ref() {
match lock {
patch_db::LockType::None => Ok(data),
}
} else {
self.tx.get(&self.ptr, lock)
}
}
}
}
}
fn build_model_enum(ast: &syn::DataEnum) -> TokenStream {
todo!()
}

View File

@@ -0,0 +1,16 @@
[package]
name = "patch-db-derive"
version = "0.1.0"
authors = ["Aiden McClelland <me@drbonez.dev>"]
edition = "2018"
description = "derive macros for defining typed patch dbs"
license = "MIT"
repository = "https://github.com/dr-bonez/patch-db"
[lib]
proc-macro = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
patch-db-derive-internals = { path = "../patch-db-derive-internals" }
syn = "1.0.62"

View File

@@ -0,0 +1,13 @@
extern crate proc_macro;
use proc_macro::TokenStream;
#[proc_macro_derive(Model, attributes(serde))]
pub fn model_derive(input: TokenStream) -> TokenStream {
// Construct a representation of Rust code as a syntax tree
// that we can manipulate
let ast = syn::parse(input).unwrap();
// Build the trait implementation
patch_db_derive_internals::build_model(&ast).into()
}

View File

@@ -12,7 +12,7 @@ fd-lock-rs = "0.1.3"
futures = "0.3.8" futures = "0.3.8"
json-patch = { path = "../../json-patch" } json-patch = { path = "../../json-patch" }
json-ptr = { path = "../../json-ptr" } json-ptr = { path = "../../json-ptr" }
nix = "0.19.1" nix = "0.20.0"
qutex-2 = { path = "../../qutex" } qutex-2 = { path = "../../qutex" }
serde = { version = "1.0.118", features = ["rc"] } serde = { version = "1.0.118", features = ["rc"] }
serde_json = "1.0.61" serde_json = "1.0.61"

View File

@@ -16,7 +16,7 @@ use tokio::{
fs::File, fs::File,
sync::{ sync::{
broadcast::{Receiver, Sender}, broadcast::{Receiver, Sender},
RwLock, Mutex, RwLock,
}, },
}; };
@@ -228,9 +228,28 @@ pub trait Checkpoint {
&'a self, &'a self,
ptr: &'a JsonPointer<S, V>, ptr: &'a JsonPointer<S, V>,
) -> BoxFuture<Result<Value, Error>>; ) -> BoxFuture<Result<Value, Error>>;
fn locks(&self) -> &[(JsonPointer, LockerGuard)]; fn locker_and_locks(&mut self) -> (&Locker, Vec<&mut [(JsonPointer, LockerGuard)]>);
fn locker(&self) -> &Locker;
fn apply(&mut self, patch: DiffPatch); fn apply(&mut self, patch: DiffPatch);
fn get<
'a,
T: for<'de> Deserialize<'de> + 'a,
S: AsRef<str> + Clone + Send + Sync + 'a,
V: SegList + Clone + Send + Sync + 'a,
>(
&'a mut self,
ptr: &'a JsonPointer<S, V>,
lock: LockType,
) -> BoxFuture<'a, Result<T, Error>>;
fn put<
'a,
T: Serialize + Send + Sync + 'a,
S: AsRef<str> + Send + Sync + 'a,
V: SegList + Send + Sync + 'a,
>(
&'a mut self,
ptr: &'a JsonPointer<S, V>,
value: &'a T,
) -> BoxFuture<'a, Result<(), Error>>;
} }
pub struct Transaction { pub struct Transaction {
@@ -276,8 +295,18 @@ impl Transaction {
) { ) {
match lock { match lock {
LockType::None => (), LockType::None => (),
LockType::Read => self.db.locker.add_read_lock(ptr, &mut self.locks).await, LockType::Read => {
LockType::Write => self.db.locker.add_write_lock(ptr, &mut self.locks).await, self.db
.locker
.add_read_lock(ptr, &mut self.locks, &mut [])
.await
}
LockType::Write => {
self.db
.locker
.add_write_lock(ptr, &mut self.locks, &mut [])
.await
}
} }
} }
pub async fn get<T: for<'de> Deserialize<'de>, S: AsRef<str> + Clone, V: SegList + Clone>( pub async fn get<T: for<'de> Deserialize<'de>, S: AsRef<str> + Clone, V: SegList + Clone>(
@@ -292,7 +321,7 @@ impl Transaction {
} }
pub async fn put<T: Serialize, S: AsRef<str>, V: SegList>( pub async fn put<T: Serialize, S: AsRef<str>, V: SegList>(
&mut self, &mut self,
ptr: &JsonPointer<S>, ptr: &JsonPointer<S, V>,
value: &T, value: &T,
) -> Result<(), Error> { ) -> Result<(), Error> {
let old = Transaction::get_value(self, ptr).await?; let old = Transaction::get_value(self, ptr).await?;
@@ -314,15 +343,36 @@ impl<'a> Checkpoint for &'a mut Transaction {
) -> BoxFuture<'b, Result<Value, Error>> { ) -> BoxFuture<'b, Result<Value, Error>> {
Transaction::get_value(self, ptr).boxed() Transaction::get_value(self, ptr).boxed()
} }
fn locks(&self) -> &[(JsonPointer, LockerGuard)] { fn locker_and_locks(&mut self) -> (&Locker, Vec<&mut [(JsonPointer, LockerGuard)]>) {
&self.locks (&self.db.locker, vec![&mut self.locks])
}
fn locker(&self) -> &Locker {
&self.db.locker
} }
fn apply(&mut self, patch: DiffPatch) { fn apply(&mut self, patch: DiffPatch) {
(self.updates.0).0.extend((patch.0).0) (self.updates.0).0.extend((patch.0).0)
} }
fn get<
'b,
T: for<'de> Deserialize<'de> + 'b,
S: AsRef<str> + Clone + Send + Sync + 'b,
V: SegList + Clone + Send + Sync + 'b,
>(
&'b mut self,
ptr: &'b JsonPointer<S, V>,
lock: LockType,
) -> BoxFuture<'b, Result<T, Error>> {
Transaction::get(self, ptr, lock).boxed()
}
fn put<
'b,
T: Serialize + Send + Sync + 'b,
S: AsRef<str> + Send + Sync + 'b,
V: SegList + Send + Sync + 'b,
>(
&'b mut self,
ptr: &'b JsonPointer<S, V>,
value: &'b T,
) -> BoxFuture<'b, Result<(), Error>> {
Transaction::put(self, ptr, value).boxed()
}
} }
pub struct SubTransaction<Tx: Checkpoint> { pub struct SubTransaction<Tx: Checkpoint> {
@@ -363,23 +413,16 @@ impl<Tx: Checkpoint> SubTransaction<Tx> {
ptr: &JsonPointer<S, V>, ptr: &JsonPointer<S, V>,
lock: LockType, lock: LockType,
) { ) {
for lock in self.locks.iter() {
if ptr.starts_with(&lock.0) {
return;
}
}
match lock { match lock {
LockType::None => (), LockType::None => (),
LockType::Read => { LockType::Read => {
self.parent let (locker, mut locks) = self.parent.locker_and_locks();
.locker() locker.add_read_lock(ptr, &mut self.locks, &mut locks).await
.add_read_lock(ptr, &mut self.locks)
.await
} }
LockType::Write => { LockType::Write => {
self.parent let (locker, mut locks) = self.parent.locker_and_locks();
.locker() locker
.add_write_lock(ptr, &mut self.locks) .add_write_lock(ptr, &mut self.locks, &mut locks)
.await .await
} }
} }
@@ -400,7 +443,7 @@ impl<Tx: Checkpoint> SubTransaction<Tx> {
} }
pub async fn put<T: Serialize, S: AsRef<str> + Send + Sync, V: SegList + Send + Sync>( pub async fn put<T: Serialize, S: AsRef<str> + Send + Sync, V: SegList + Send + Sync>(
&mut self, &mut self,
ptr: &JsonPointer<S>, ptr: &JsonPointer<S, V>,
value: &T, value: &T,
) -> Result<(), Error> { ) -> Result<(), Error> {
let old = SubTransaction::get_value(self, ptr).await?; let old = SubTransaction::get_value(self, ptr).await?;
@@ -422,15 +465,38 @@ impl<'a, Tx: Checkpoint + Send + Sync> Checkpoint for &'a mut SubTransaction<Tx>
) -> BoxFuture<'b, Result<Value, Error>> { ) -> BoxFuture<'b, Result<Value, Error>> {
SubTransaction::get_value(self, ptr).boxed() SubTransaction::get_value(self, ptr).boxed()
} }
fn locks(&self) -> &[(JsonPointer, LockerGuard)] { fn locker_and_locks(&mut self) -> (&Locker, Vec<&mut [(JsonPointer, LockerGuard)]>) {
&self.locks let (locker, mut locks) = self.parent.locker_and_locks();
} locks.push(&mut self.locks);
fn locker(&self) -> &Locker { (locker, locks)
&self.parent.locker()
} }
fn apply(&mut self, patch: DiffPatch) { fn apply(&mut self, patch: DiffPatch) {
(self.updates.0).0.extend((patch.0).0) (self.updates.0).0.extend((patch.0).0)
} }
fn get<
'b,
T: for<'de> Deserialize<'de> + 'b,
S: AsRef<str> + Clone + Send + Sync + 'b,
V: SegList + Clone + Send + Sync + 'b,
>(
&'b mut self,
ptr: &'b JsonPointer<S, V>,
lock: LockType,
) -> BoxFuture<'b, Result<T, Error>> {
SubTransaction::get(self, ptr, lock).boxed()
}
fn put<
'b,
T: Serialize + Send + Sync + 'b,
S: AsRef<str> + Send + Sync + 'b,
V: SegList + Send + Sync + 'b,
>(
&'b mut self,
ptr: &'b JsonPointer<S, V>,
value: &'b T,
) -> BoxFuture<'b, Result<(), Error>> {
SubTransaction::put(self, ptr, value).boxed()
}
} }
#[derive(Debug)] #[derive(Debug)]
@@ -442,8 +508,8 @@ pub enum LockType {
pub enum LockerGuard { pub enum LockerGuard {
Empty, Empty,
Read(ReadGuard<HashMap<String, Locker>>), Read(LockerReadGuard),
Write(WriteGuard<HashMap<String, Locker>>), Write(LockerWriteGuard),
} }
impl LockerGuard { impl LockerGuard {
pub fn take(&mut self) -> Self { pub fn take(&mut self) -> Self {
@@ -451,6 +517,44 @@ impl LockerGuard {
} }
} }
#[derive(Debug, Clone)]
pub struct LockerReadGuard(Arc<Mutex<Option<ReadGuard<HashMap<String, Locker>>>>>);
impl LockerReadGuard {
async fn upgrade(&self) -> Option<LockerWriteGuard> {
let guard = self.0.try_lock().unwrap().take();
if let Some(g) = guard {
Some(LockerWriteGuard(
Some(ReadGuard::upgrade(g).await.unwrap()),
Some(self.clone()),
))
} else {
None
}
}
}
impl From<ReadGuard<HashMap<String, Locker>>> for LockerReadGuard {
fn from(guard: ReadGuard<HashMap<String, Locker>>) -> Self {
LockerReadGuard(Arc::new(Mutex::new(Some(guard))))
}
}
pub struct LockerWriteGuard(
Option<WriteGuard<HashMap<String, Locker>>>,
Option<LockerReadGuard>,
);
impl From<WriteGuard<HashMap<String, Locker>>> for LockerWriteGuard {
fn from(guard: WriteGuard<HashMap<String, Locker>>) -> Self {
LockerWriteGuard(Some(guard), None)
}
}
impl Drop for LockerWriteGuard {
fn drop(&mut self) {
if let (Some(write), Some(read)) = (self.0.take(), self.1.take()) {
*read.0.try_lock().unwrap() = Some(WriteGuard::downgrade(write));
}
}
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Locker(QrwLock<HashMap<String, Locker>>); pub struct Locker(QrwLock<HashMap<String, Locker>>);
impl Locker { impl Locker {
@@ -475,24 +579,29 @@ impl Locker {
} }
lock.unwrap() lock.unwrap()
} }
pub async fn add_read_lock<S: AsRef<str> + Clone, V: SegList + Clone>( async fn add_read_lock<S: AsRef<str> + Clone, V: SegList + Clone>(
&self, &self,
ptr: &JsonPointer<S, V>, ptr: &JsonPointer<S, V>,
locks: &mut Vec<(JsonPointer, LockerGuard)>, locks: &mut Vec<(JsonPointer, LockerGuard)>,
extra_locks: &mut [&mut [(JsonPointer, LockerGuard)]],
) { ) {
for lock in locks.iter() { for lock in extra_locks
.iter()
.flat_map(|a| a.iter())
.chain(locks.iter())
{
if ptr.starts_with(&lock.0) { if ptr.starts_with(&lock.0) {
return; return;
} }
} }
locks.push(( locks.push((
JsonPointer::to_owned(ptr.clone()), JsonPointer::to_owned(ptr.clone()),
LockerGuard::Read(self.lock_read(ptr).await), LockerGuard::Read(self.lock_read(ptr).await.into()),
)); ));
} }
pub async fn lock_write<S: AsRef<str>, V: SegList>( pub async fn lock_write<S: AsRef<str>, V: SegList>(
&self, &self,
ptr: &JsonPointer<S>, ptr: &JsonPointer<S, V>,
) -> WriteGuard<HashMap<String, Locker>> { ) -> WriteGuard<HashMap<String, Locker>> {
let mut lock = self.0.clone().write().await.unwrap(); let mut lock = self.0.clone().write().await.unwrap();
for seg in ptr.iter() { for seg in ptr.iter() {
@@ -506,26 +615,86 @@ impl Locker {
} }
lock lock
} }
pub async fn add_write_lock<S: AsRef<str> + Clone, V: SegList + Clone>( async fn add_write_lock<S: AsRef<str> + Clone, V: SegList + Clone>(
&self, &self,
ptr: &JsonPointer<S, V>, ptr: &JsonPointer<S, V>,
locks: &mut Vec<(JsonPointer, LockerGuard)>, locks: &mut Vec<(JsonPointer, LockerGuard)>,
extra_locks: &mut [&mut [(JsonPointer, LockerGuard)]],
) { ) {
for lock in locks.iter_mut() { let mut final_lock = None;
if ptr.starts_with(&lock.0) { for lock in extra_locks
.iter_mut()
.flat_map(|a| a.iter_mut())
.chain(locks.iter_mut())
{
enum Choice {
Return,
Continue,
Break,
}
let choice: Choice;
if let Some(remainder) = ptr.strip_prefix(&lock.0) {
let guard = lock.1.take(); let guard = lock.1.take();
lock.1 = match guard { lock.1 = match guard {
LockerGuard::Read(l) => { LockerGuard::Read(LockerReadGuard(guard)) if !remainder.is_empty() => {
LockerGuard::Write(ReadGuard::upgrade(l).await.unwrap()) // read guard already exists at higher level
let mut lock = guard.lock().await;
if let Some(l) = lock.take() {
let mut orig_lock = None;
let mut lock = ReadGuard::upgrade(l).await.unwrap();
for seg in remainder.iter() {
let new_lock = if let Some(locker) = lock.get(seg) {
locker.0.clone().write().await.unwrap()
} else {
lock.insert(seg.to_owned(), Locker::new());
lock.get(seg).unwrap().0.clone().write().await.unwrap()
};
if orig_lock.is_none() {
orig_lock = Some(lock);
}
lock = new_lock;
}
final_lock = Some(LockerGuard::Write(lock.into()));
choice = Choice::Break;
LockerGuard::Read(WriteGuard::downgrade(orig_lock.unwrap()).into())
} else {
drop(lock);
choice = Choice::Return;
LockerGuard::Read(LockerReadGuard(guard))
}
}
LockerGuard::Read(l) => {
// read exists, convert to write
if let Some(upgraded) = l.upgrade().await {
final_lock = Some(LockerGuard::Write(upgraded));
choice = Choice::Break;
} else {
choice = Choice::Continue;
}
LockerGuard::Read(l)
}
LockerGuard::Write(l) => {
choice = Choice::Return;
LockerGuard::Write(l)
} // leave it alone, already sufficiently locked
LockerGuard::Empty => {
unreachable!("LockerGuard found empty");
} }
a => a,
}; };
return; match choice {
Choice::Return => return,
Choice::Break => break,
Choice::Continue => continue,
}
} }
} }
locks.push(( locks.push((
JsonPointer::to_owned(ptr.clone()), JsonPointer::to_owned(ptr.clone()),
LockerGuard::Read(self.lock_read(ptr).await), if let Some(lock) = final_lock {
lock
} else {
LockerGuard::Write(self.lock_write(ptr).await.into())
},
)); ));
} }
} }