From 6a0e9d5c0aa5cd5cecf7e89b0f498b5c59343947 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Thu, 19 May 2022 11:00:59 -0600 Subject: [PATCH] refactor packer to async --- backend/Cargo.toml | 2 +- backend/src/action.rs | 2 - backend/src/s9pk/builder.rs | 84 ++++++++++++++++---------------- backend/src/s9pk/header.rs | 52 +++++++++++--------- backend/src/s9pk/mod.rs | 97 ++++++++++++++++++++++--------------- backend/src/util/io.rs | 11 +++++ backend/src/util/mod.rs | 40 +++++++++++---- backend/src/util/serde.rs | 15 ++++++ 8 files changed, 188 insertions(+), 115 deletions(-) diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 443915717..5943f980f 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -86,7 +86,7 @@ num_enum = "0.5.4" openssh-keys = "0.5.0" openssl = { version = "0.10.36", features = ["vendored"] } patch-db = { version = "*", path = "../patch-db/patch-db", features = [ - "trace" + "trace", ] } pbkdf2 = "0.9.0" pin-project = "1.0.8" diff --git a/backend/src/action.rs b/backend/src/action.rs index 7636da334..8a4ae3a77 100644 --- a/backend/src/action.rs +++ b/backend/src/action.rs @@ -1,12 +1,10 @@ use std::collections::{BTreeMap, BTreeSet}; use std::path::Path; use std::str::FromStr; -use std::time::Duration; use clap::ArgMatches; use color_eyre::eyre::eyre; use indexmap::IndexSet; -use patch_db::HasModel; use rpc_toolkit::command; use serde::{Deserialize, Serialize}; use tracing::instrument; diff --git a/backend/src/s9pk/builder.rs b/backend/src/s9pk/builder.rs index 8478215a7..a973b4276 100644 --- a/backend/src/s9pk/builder.rs +++ b/backend/src/s9pk/builder.rs @@ -1,26 +1,26 @@ -use std::io::{Read, Seek, SeekFrom, Write}; - use digest::Digest; use sha2::Sha512; +use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt, SeekFrom}; use tracing::instrument; use typed_builder::TypedBuilder; use super::header::{FileSection, Header}; use super::manifest::Manifest; use super::SIG_CONTEXT; +use crate::util::io::to_cbor_async_writer; use crate::util::HashWriter; use crate::{Error, ResultExt}; #[derive(TypedBuilder)] pub struct S9pkPacker< 'a, - W: Write + Seek, - RLicense: Read, - RInstructions: Read, - RIcon: Read, - RDockerImages: Read, - RAssets: Read, - RScripts: Read, + W: AsyncWriteExt + AsyncSeekExt, + RLicense: AsyncReadExt + Unpin, + RInstructions: AsyncReadExt + Unpin, + RIcon: AsyncReadExt + Unpin, + RDockerImages: AsyncReadExt + Unpin, + RAssets: AsyncReadExt + Unpin, + RScripts: AsyncReadExt + Unpin, > { writer: W, manifest: &'a Manifest, @@ -33,85 +33,85 @@ pub struct S9pkPacker< } impl< 'a, - W: Write + Seek, - RLicense: Read, - RInstructions: Read, - RIcon: Read, - RDockerImages: Read, - RAssets: Read, - RScripts: Read, + W: AsyncWriteExt + AsyncSeekExt + Unpin, + RLicense: AsyncReadExt + Unpin, + RInstructions: AsyncReadExt + Unpin, + RIcon: AsyncReadExt + Unpin, + RDockerImages: AsyncReadExt + Unpin, + RAssets: AsyncReadExt + Unpin, + RScripts: AsyncReadExt + Unpin, > S9pkPacker<'a, W, RLicense, RInstructions, RIcon, RDockerImages, RAssets, RScripts> { /// BLOCKING #[instrument(skip(self))] - pub fn pack(mut self, key: &ed25519_dalek::Keypair) -> Result<(), Error> { - let header_pos = self.writer.stream_position()?; + pub async fn pack(mut self, key: &ed25519_dalek::Keypair) -> Result<(), Error> { + let header_pos = self.writer.stream_position().await?; if header_pos != 0 { tracing::warn!("Appending to non-empty file."); } let mut header = Header::placeholder(); - header.serialize(&mut self.writer).with_ctx(|_| { + header.serialize(&mut self.writer).await.with_ctx(|_| { ( crate::ErrorKind::Serialization, "Writing Placeholder Header", ) })?; - let mut position = self.writer.stream_position()?; + let mut position = self.writer.stream_position().await?; let mut writer = HashWriter::new(Sha512::new(), &mut self.writer); // manifest - serde_cbor::ser::into_writer(self.manifest, &mut writer).with_ctx(|_| { - ( - crate::ErrorKind::Serialization, - "Serializing Manifest (CBOR)", - ) - })?; - let new_pos = writer.inner_mut().stream_position()?; + to_cbor_async_writer(&mut writer, self.manifest).await?; + let new_pos = writer.inner_mut().stream_position().await?; header.table_of_contents.manifest = FileSection { position, length: new_pos - position, }; position = new_pos; // license - std::io::copy(&mut self.license, &mut writer) + tokio::io::copy(&mut self.license, &mut writer) + .await .with_ctx(|_| (crate::ErrorKind::Filesystem, "Copying License"))?; - let new_pos = writer.inner_mut().stream_position()?; + let new_pos = writer.inner_mut().stream_position().await?; header.table_of_contents.license = FileSection { position, length: new_pos - position, }; position = new_pos; // instructions - std::io::copy(&mut self.instructions, &mut writer) + tokio::io::copy(&mut self.instructions, &mut writer) + .await .with_ctx(|_| (crate::ErrorKind::Filesystem, "Copying Instructions"))?; - let new_pos = writer.inner_mut().stream_position()?; + let new_pos = writer.inner_mut().stream_position().await?; header.table_of_contents.instructions = FileSection { position, length: new_pos - position, }; position = new_pos; // icon - std::io::copy(&mut self.icon, &mut writer) + tokio::io::copy(&mut self.icon, &mut writer) + .await .with_ctx(|_| (crate::ErrorKind::Filesystem, "Copying Icon"))?; - let new_pos = writer.inner_mut().stream_position()?; + let new_pos = writer.inner_mut().stream_position().await?; header.table_of_contents.icon = FileSection { position, length: new_pos - position, }; position = new_pos; // docker_images - std::io::copy(&mut self.docker_images, &mut writer) + tokio::io::copy(&mut self.docker_images, &mut writer) + .await .with_ctx(|_| (crate::ErrorKind::Filesystem, "Copying Docker Images"))?; - let new_pos = writer.inner_mut().stream_position()?; + let new_pos = writer.inner_mut().stream_position().await?; header.table_of_contents.docker_images = FileSection { position, length: new_pos - position, }; position = new_pos; // assets - std::io::copy(&mut self.assets, &mut writer) + tokio::io::copy(&mut self.assets, &mut writer) + .await .with_ctx(|_| (crate::ErrorKind::Filesystem, "Copying Assets"))?; - let new_pos = writer.inner_mut().stream_position()?; + let new_pos = writer.inner_mut().stream_position().await?; header.table_of_contents.assets = FileSection { position, length: new_pos - position, @@ -119,9 +119,10 @@ impl< position = new_pos; // scripts if let Some(mut scripts) = self.scripts { - std::io::copy(&mut scripts, &mut writer) + tokio::io::copy(&mut scripts, &mut writer) + .await .with_ctx(|_| (crate::ErrorKind::Filesystem, "Copying Scripts"))?; - let new_pos = writer.inner_mut().stream_position()?; + let new_pos = writer.inner_mut().stream_position().await?; header.table_of_contents.scripts = Some(FileSection { position, length: new_pos - position, @@ -131,13 +132,14 @@ impl< // header let (hash, _) = writer.finish(); - self.writer.seek(SeekFrom::Start(header_pos))?; + self.writer.seek(SeekFrom::Start(header_pos)).await?; header.pubkey = key.public.clone(); header.signature = key.sign_prehashed(hash, Some(SIG_CONTEXT))?; header .serialize(&mut self.writer) + .await .with_ctx(|_| (crate::ErrorKind::Serialization, "Writing Header"))?; - self.writer.seek(SeekFrom::Start(position))?; + self.writer.seek(SeekFrom::Start(position)).await?; Ok(()) } diff --git a/backend/src/s9pk/header.rs b/backend/src/s9pk/header.rs index f07f22e14..9e8bfae7e 100644 --- a/backend/src/s9pk/header.rs +++ b/backend/src/s9pk/header.rs @@ -1,9 +1,8 @@ use std::collections::BTreeMap; -use std::io::Write; use color_eyre::eyre::eyre; use ed25519_dalek::{PublicKey, Signature}; -use tokio::io::{AsyncRead, AsyncReadExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; use crate::Error; @@ -25,12 +24,12 @@ impl Header { } } // MUST BE SAME SIZE REGARDLESS OF DATA - pub fn serialize(&self, mut writer: W) -> std::io::Result<()> { - writer.write_all(&MAGIC)?; - writer.write_all(&[VERSION])?; - writer.write_all(self.pubkey.as_bytes())?; - writer.write_all(self.signature.as_ref())?; - self.table_of_contents.serialize(writer)?; + pub async fn serialize(&self, mut writer: W) -> std::io::Result<()> { + writer.write_all(&MAGIC).await?; + writer.write_all(&[VERSION]).await?; + writer.write_all(self.pubkey.as_bytes()).await?; + writer.write_all(self.signature.as_ref()).await?; + self.table_of_contents.serialize(writer).await?; Ok(()) } pub async fn deserialize(mut reader: R) -> Result { @@ -78,7 +77,7 @@ pub struct TableOfContents { pub scripts: Option, } impl TableOfContents { - pub fn serialize(&self, mut writer: W) -> std::io::Result<()> { + pub async fn serialize(&self, mut writer: W) -> std::io::Result<()> { let len: u32 = ((1 + "manifest".len() + 16) + (1 + "license".len() + 16) + (1 + "instructions".len() + 16) @@ -86,18 +85,23 @@ impl TableOfContents { + (1 + "docker_images".len() + 16) + (1 + "assets".len() + 16) + (1 + "scripts".len() + 16)) as u32; - writer.write_all(&u32::to_be_bytes(len))?; - self.manifest.serialize_entry("manifest", &mut writer)?; - self.license.serialize_entry("license", &mut writer)?; + writer.write_all(&u32::to_be_bytes(len)).await?; + self.manifest + .serialize_entry("manifest", &mut writer) + .await?; + self.license.serialize_entry("license", &mut writer).await?; self.instructions - .serialize_entry("instructions", &mut writer)?; - self.icon.serialize_entry("icon", &mut writer)?; + .serialize_entry("instructions", &mut writer) + .await?; + self.icon.serialize_entry("icon", &mut writer).await?; self.docker_images - .serialize_entry("docker_images", &mut writer)?; - self.assets.serialize_entry("assets", &mut writer)?; + .serialize_entry("docker_images", &mut writer) + .await?; + self.assets.serialize_entry("assets", &mut writer).await?; self.scripts .unwrap_or_default() - .serialize_entry("scripts", &mut writer)?; + .serialize_entry("scripts", &mut writer) + .await?; Ok(()) } pub async fn deserialize(mut reader: R) -> std::io::Result { @@ -147,11 +151,15 @@ pub struct FileSection { pub length: u64, } impl FileSection { - pub fn serialize_entry(self, label: &str, mut writer: W) -> std::io::Result<()> { - writer.write_all(&[label.len() as u8])?; - writer.write_all(label.as_bytes())?; - writer.write_all(&u64::to_be_bytes(self.position))?; - writer.write_all(&u64::to_be_bytes(self.length))?; + pub async fn serialize_entry( + self, + label: &str, + mut writer: W, + ) -> std::io::Result<()> { + writer.write_all(&[label.len() as u8]).await?; + writer.write_all(label.as_bytes()).await?; + writer.write_all(&u64::to_be_bytes(self.position)).await?; + writer.write_all(&u64::to_be_bytes(self.length)).await?; Ok(()) } pub async fn deserialize_entry( diff --git a/backend/src/s9pk/mod.rs b/backend/src/s9pk/mod.rs index 57601572e..3170926f3 100644 --- a/backend/src/s9pk/mod.rs +++ b/backend/src/s9pk/mod.rs @@ -22,10 +22,10 @@ pub mod reader; pub const SIG_CONTEXT: &'static [u8] = b"s9pk"; -#[command(cli_only, display(display_none), blocking)] +#[command(cli_only, display(display_none))] #[instrument(skip(ctx))] -pub fn pack(#[context] ctx: SdkContext, #[arg] path: Option) -> Result<(), Error> { - use std::fs::File; +pub async fn pack(#[context] ctx: SdkContext, #[arg] path: Option) -> Result<(), Error> { + use tokio::fs::File; let path = if let Some(path) = path { path @@ -33,11 +33,17 @@ pub fn pack(#[context] ctx: SdkContext, #[arg] path: Option) -> Result< std::env::current_dir()? }; let manifest_value: Value = if path.join("manifest.toml").exists() { - IoFormat::Toml.from_reader(File::open(path.join("manifest.toml"))?)? + IoFormat::Toml + .from_async_reader(File::open(path.join("manifest.toml")).await?) + .await? } else if path.join("manifest.yaml").exists() { - IoFormat::Yaml.from_reader(File::open(path.join("manifest.yaml"))?)? + IoFormat::Yaml + .from_async_reader(File::open(path.join("manifest.yaml")).await?) + .await? } else if path.join("manifest.json").exists() { - IoFormat::Json.from_reader(File::open(path.join("manifest.json"))?)? + IoFormat::Json + .from_async_reader(File::open(path.join("manifest.json")).await?) + .await? } else { return Err(Error::new( eyre!("manifest not found"), @@ -53,69 +59,80 @@ pub fn pack(#[context] ctx: SdkContext, #[arg] path: Option) -> Result< } let outfile_path = path.join(format!("{}.s9pk", manifest.id)); - let mut outfile = File::create(outfile_path)?; + let mut outfile = File::create(outfile_path).await?; S9pkPacker::builder() .manifest(&manifest) .writer(&mut outfile) .license( - File::open(path.join(manifest.assets.license_path())).with_ctx(|_| { - ( - crate::ErrorKind::Filesystem, - manifest.assets.license_path().display().to_string(), - ) - })?, + File::open(path.join(manifest.assets.license_path())) + .await + .with_ctx(|_| { + ( + crate::ErrorKind::Filesystem, + manifest.assets.license_path().display().to_string(), + ) + })?, ) .icon( - File::open(path.join(manifest.assets.icon_path())).with_ctx(|_| { - ( - crate::ErrorKind::Filesystem, - manifest.assets.icon_path().display().to_string(), - ) - })?, + File::open(path.join(manifest.assets.icon_path())) + .await + .with_ctx(|_| { + ( + crate::ErrorKind::Filesystem, + manifest.assets.icon_path().display().to_string(), + ) + })?, ) .instructions( - File::open(path.join(manifest.assets.instructions_path())).with_ctx(|_| { - ( - crate::ErrorKind::Filesystem, - manifest.assets.instructions_path().display().to_string(), - ) - })?, + File::open(path.join(manifest.assets.instructions_path())) + .await + .with_ctx(|_| { + ( + crate::ErrorKind::Filesystem, + manifest.assets.instructions_path().display().to_string(), + ) + })?, ) .docker_images( - File::open(path.join(manifest.assets.docker_images_path())).with_ctx(|_| { - ( - crate::ErrorKind::Filesystem, - manifest.assets.docker_images_path().display().to_string(), - ) - })?, + File::open(path.join(manifest.assets.docker_images_path())) + .await + .with_ctx(|_| { + ( + crate::ErrorKind::Filesystem, + manifest.assets.docker_images_path().display().to_string(), + ) + })?, ) .assets({ - let mut assets = tar::Builder::new(Vec::new()); // TODO: Ideally stream this? best not to buffer in memory + let mut assets = tokio_tar::Builder::new(Vec::new()); // TODO: Ideally stream this? best not to buffer in memory for (asset_volume, _) in manifest .volumes .iter() .filter(|(_, v)| matches!(v, &&Volume::Assets {})) { - assets.append_dir_all( - asset_volume, - path.join(manifest.assets.assets_path()).join(asset_volume), - )?; + assets + .append_dir_all( + asset_volume, + path.join(manifest.assets.assets_path()).join(asset_volume), + ) + .await?; } - std::io::Cursor::new(assets.into_inner()?) + std::io::Cursor::new(assets.into_inner().await?) }) .scripts({ let script_path = path.join(manifest.assets.scripts_path()).join("embassy.js"); if script_path.exists() { - Some(File::open(script_path)?) + Some(File::open(script_path).await?) } else { None } }) .build() - .pack(&ctx.developer_key()?)?; - outfile.sync_all()?; + .pack(&ctx.developer_key()?) + .await?; + outfile.sync_all().await?; Ok(()) } diff --git a/backend/src/util/io.rs b/backend/src/util/io.rs index a3ef0bc98..8953b1c51 100644 --- a/backend/src/util/io.rs +++ b/backend/src/util/io.rs @@ -153,6 +153,17 @@ where .map_err(color_eyre::eyre::Error::from) .with_kind(crate::ErrorKind::Deserialization) } +pub async fn to_cbor_async_writer(mut writer: W, value: &T) -> Result<(), crate::Error> +where + T: serde::Serialize, + W: AsyncWrite + Unpin, +{ + let mut buffer = Vec::new(); + serde_cbor::ser::into_writer(value, &mut buffer).with_kind(crate::ErrorKind::Serialization)?; + buffer.extend_from_slice(b"\n"); + writer.write_all(&buffer).await?; + Ok(()) +} pub async fn from_json_async_reader(mut reader: R) -> Result where diff --git a/backend/src/util/mod.rs b/backend/src/util/mod.rs index 9646b9f17..1644ed564 100644 --- a/backend/src/util/mod.rs +++ b/backend/src/util/mod.rs @@ -4,9 +4,11 @@ use std::hash::{Hash, Hasher}; use std::marker::PhantomData; use std::ops::Deref; use std::path::{Path, PathBuf}; +use std::pin::Pin; use std::process::Stdio; use std::str::FromStr; use std::sync::Arc; +use std::task::{Context, Poll}; use ::serde::{Deserialize, Deserializer, Serialize, Serializer}; use async_trait::async_trait; @@ -18,6 +20,7 @@ use futures::future::BoxFuture; use futures::FutureExt; use lazy_static::lazy_static; use patch_db::{HasModel, Model}; +use pin_project::pin_project; use tokio::fs::File; use tokio::sync::{Mutex, OwnedMutexGuard, RwLock}; use tokio::task::{JoinError, JoinHandle}; @@ -292,11 +295,13 @@ impl Container { } } -pub struct HashWriter { +#[pin_project] +pub struct HashWriter { hasher: H, + #[pin] writer: W, } -impl HashWriter { +impl HashWriter { pub fn new(hasher: H, writer: W) -> Self { HashWriter { hasher, writer } } @@ -310,14 +315,31 @@ impl HashWriter { &mut self.writer } } -impl std::io::Write for HashWriter { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - let written = self.writer.write(buf)?; - self.hasher.update(&buf[..written]); - Ok(written) +impl tokio::io::AsyncWrite for HashWriter { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + let this = self.project(); + let written = tokio::io::AsyncWrite::poll_write(this.writer, cx, &buf); + match written { + // only update the hasher once + Poll::Ready(res) => { + if let Ok(n) = res { + this.hasher.update(&buf[..n]); + } + Poll::Ready(res) + } + Poll::Pending => Poll::Pending, + } } - fn flush(&mut self) -> std::io::Result<()> { - self.writer.flush() + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().writer.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.project().writer.poll_shutdown(cx) } } diff --git a/backend/src/util/serde.rs b/backend/src/util/serde.rs index 59b78be91..a9b5cb654 100644 --- a/backend/src/util/serde.rs +++ b/backend/src/util/serde.rs @@ -379,6 +379,21 @@ impl IoFormat { } } } + pub async fn from_async_reader< + R: tokio::io::AsyncRead + Unpin, + T: for<'de> Deserialize<'de>, + >( + &self, + reader: R, + ) -> Result { + use crate::util::io::*; + match self { + IoFormat::Json | IoFormat::JsonPretty => from_json_async_reader(reader).await, + IoFormat::Yaml => from_yaml_async_reader(reader).await, + IoFormat::Cbor => from_cbor_async_reader(reader).await, + IoFormat::Toml | IoFormat::TomlPretty => from_toml_async_reader(reader).await, + } + } pub fn from_slice Deserialize<'de>>(&self, slice: &[u8]) -> Result { match self { IoFormat::Json | IoFormat::JsonPretty => {