From cb8998a79c32f54baa878149226157e987b92bd0 Mon Sep 17 00:00:00 2001 From: Aiden McClelland <3732071+dr-bonez@users.noreply.github.com> Date: Wed, 26 Oct 2022 14:19:26 -0600 Subject: [PATCH] make `Store::apply` cancel safe (#53) --- patch-db/src/store.rs | 75 +++++++++++++++++++++++++++---------------- 1 file changed, 48 insertions(+), 27 deletions(-) diff --git a/patch-db/src/store.rs b/patch-db/src/store.rs index db0d341..d236d5f 100644 --- a/patch-db/src/store.rs +++ b/patch-db/src/store.rs @@ -6,6 +6,7 @@ use std::sync::atomic::AtomicU64; use std::sync::Arc; use fd_lock_rs::FdLock; +use json_patch::PatchError; use json_ptr::{JsonPointer, SegList}; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; @@ -221,43 +222,63 @@ impl Store { pub(crate) async fn apply(&mut self, patch: DiffPatch) -> Result>, Error> { use tokio::io::AsyncWriteExt; + // eject if noop if (patch.0).0.is_empty() { return Ok(None); } + struct TentativeUpdated<'a> { + store: &'a mut Store, + undo: Option>, + } + impl<'a> TentativeUpdated<'a> { + fn new(store: &'a mut Store, patch: &'a DiffPatch) -> Result { + let undo = json_patch::patch(&mut store.persistent, &*patch)?; + store.revision += 1; + Ok(Self { + store, + undo: Some(undo), + }) + } + } + impl<'a> Drop for TentativeUpdated<'a> { + fn drop(&mut self) { + if let Some(undo) = self.undo.take() { + undo.apply(&mut self.store.persistent); + self.store.revision -= 1; + } + } + } + #[cfg(feature = "tracing")] tracing::trace!("Attempting to apply patch: {:?}", patch); + // apply patch in memory let patch_bin = serde_cbor::to_vec(&*patch)?; - let persistent_undo = json_patch::patch(&mut self.persistent, &*patch)?; - self.revision += 1; - if let Err(_e) = if self.revision % 4096 == 0 { - self.compress().await + let mut updated = TentativeUpdated::new(self, &patch)?; + + if updated.store.revision % 4096 == 0 { + updated.store.compress().await? } else { - async { - if self.file.stream_position().await? != self.file_cursor { - self.file.set_len(self.file_cursor).await?; - self.file.seek(SeekFrom::Start(self.file_cursor)).await?; - } - self.file.write_all(&patch_bin).await?; - self.file.flush().await?; - self.file.sync_all().await?; - self.file_cursor += patch_bin.len() as u64; - Ok::<_, Error>(()) + if updated.store.file.stream_position().await? != updated.store.file_cursor { + updated + .store + .file + .set_len(updated.store.file_cursor) + .await?; + updated + .store + .file + .seek(SeekFrom::Start(updated.store.file_cursor)) + .await?; } - .await - } { - #[cfg(feature = "tracing")] - tracing::error!("Error saving patch to disk: {}, attempting to compress", _e); - if let Err(e) = self.compress().await { - #[cfg(feature = "tracing")] - tracing::error!("Compression failed: {}", e); - persistent_undo.apply(&mut self.persistent); - self.revision -= 1; - return Err(e); - } - }; - drop(persistent_undo); + updated.store.file.write_all(&patch_bin).await?; + updated.store.file.flush().await?; + updated.store.file.sync_all().await?; + updated.store.file_cursor += patch_bin.len() as u64; + } + drop(updated.undo.take()); + drop(updated); let id = self.revision; let res = Arc::new(Revision { id, patch });