use std::future::Future; use std::hash::{Hash, Hasher}; use std::marker::PhantomData; use std::ops::Deref; use std::path::Path; use std::process::{exit, Stdio}; use std::str::FromStr; use std::time::Duration; use async_trait::async_trait; use clap::ArgMatches; use color_eyre::eyre::{self, eyre}; use digest::Digest; use patch_db::{HasModel, Model}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_json::Value; use tokio::fs::File; use tokio::sync::RwLock; use tokio::task::{JoinError, JoinHandle}; use crate::shutdown::Shutdown; use crate::{Error, ResultExt as _}; pub mod io; pub mod logger; #[derive(Clone, Copy, Debug)] pub enum Never {} impl Never {} impl Never { pub fn absurd(self) -> T { match self {} } } impl std::fmt::Display for Never { fn fmt(&self, _f: &mut std::fmt::Formatter) -> std::fmt::Result { self.absurd() } } impl std::error::Error for Never {} #[async_trait::async_trait] pub trait Invoke { async fn invoke(&mut self, error_kind: crate::ErrorKind) -> Result, Error>; } #[async_trait::async_trait] impl Invoke for tokio::process::Command { async fn invoke(&mut self, error_kind: crate::ErrorKind) -> Result, Error> { self.stdout(Stdio::piped()); self.stderr(Stdio::piped()); let res = self.output().await?; crate::ensure_code!( res.status.success(), error_kind, "{}", std::str::from_utf8(&res.stderr).unwrap_or("Unknown Error") ); Ok(res.stdout) } } pub trait Apply: Sized { fn apply O>(self, func: F) -> O { func(self) } } pub trait ApplyRef { fn apply_ref O>(&self, func: F) -> O { func(&self) } fn apply_mut O>(&mut self, func: F) -> O { func(self) } } impl Apply for T {} impl ApplyRef for T {} pub fn deserialize_from_str< 'de, D: serde::de::Deserializer<'de>, T: FromStr, E: std::fmt::Display, >( deserializer: D, ) -> std::result::Result { struct Visitor, E>(std::marker::PhantomData); impl<'de, T: FromStr, Err: std::fmt::Display> serde::de::Visitor<'de> for Visitor { type Value = T; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(formatter, "a parsable string") } fn visit_str(self, v: &str) -> Result where E: serde::de::Error, { v.parse().map_err(|e| serde::de::Error::custom(e)) } } deserializer.deserialize_str(Visitor(std::marker::PhantomData)) } pub fn deserialize_from_str_opt< 'de, D: serde::de::Deserializer<'de>, T: FromStr, E: std::fmt::Display, >( deserializer: D, ) -> std::result::Result, D::Error> { struct Visitor, E>(std::marker::PhantomData); impl<'de, T: FromStr, Err: std::fmt::Display> serde::de::Visitor<'de> for Visitor { type Value = Option; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(formatter, "a parsable string") } fn visit_str(self, v: &str) -> Result where E: serde::de::Error, { v.parse().map(Some).map_err(|e| serde::de::Error::custom(e)) } fn visit_some(self, deserializer: D) -> Result where D: serde::de::Deserializer<'de>, { deserializer.deserialize_str(Visitor(std::marker::PhantomData)) } fn visit_none(self) -> Result where E: serde::de::Error, { Ok(None) } fn visit_unit(self) -> Result where E: serde::de::Error, { Ok(None) } } deserializer.deserialize_any(Visitor(std::marker::PhantomData)) } pub fn serialize_display( t: &T, serializer: S, ) -> Result { String::serialize(&t.to_string(), serializer) } pub fn serialize_display_opt( t: &Option, serializer: S, ) -> Result { Option::::serialize(&t.as_ref().map(|t| t.to_string()), serializer) } pub async fn daemon Fut, Fut: Future + Send + 'static>( mut f: F, cooldown: std::time::Duration, mut shutdown: tokio::sync::broadcast::Receiver>, ) -> Result<(), eyre::Error> { loop { tokio::select! { _ = shutdown.recv() => return Ok(()), _ = tokio::time::sleep(cooldown) => (), } match tokio::spawn(f()).await { Err(e) if e.is_panic() => return Err(eyre!("daemon panicked!")), _ => (), } } } pub trait SOption {} pub struct SSome(T); impl SSome { pub fn into(self) -> T { self.0 } } impl From for SSome { fn from(t: T) -> Self { SSome(t) } } impl SOption for SSome {} pub struct SNone(PhantomData); impl SNone { pub fn new() -> Self { SNone(PhantomData) } } impl SOption for SNone {} #[derive(Debug, Serialize)] #[serde(untagged)] pub enum ValuePrimative { Null, Boolean(bool), String(String), Number(serde_json::Number), } impl<'de> serde::de::Deserialize<'de> for ValuePrimative { fn deserialize(deserializer: D) -> Result where D: serde::de::Deserializer<'de>, { struct Visitor; impl<'de> serde::de::Visitor<'de> for Visitor { type Value = ValuePrimative; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { write!(formatter, "a JSON primative value") } fn visit_unit(self) -> Result where E: serde::de::Error, { Ok(ValuePrimative::Null) } fn visit_none(self) -> Result where E: serde::de::Error, { Ok(ValuePrimative::Null) } fn visit_bool(self, v: bool) -> Result where E: serde::de::Error, { Ok(ValuePrimative::Boolean(v)) } fn visit_str(self, v: &str) -> Result where E: serde::de::Error, { Ok(ValuePrimative::String(v.to_owned())) } fn visit_string(self, v: String) -> Result where E: serde::de::Error, { Ok(ValuePrimative::String(v)) } fn visit_f32(self, v: f32) -> Result where E: serde::de::Error, { Ok(ValuePrimative::Number( serde_json::Number::from_f64(v as f64).ok_or_else(|| { serde::de::Error::invalid_value( serde::de::Unexpected::Float(v as f64), &"a finite number", ) })?, )) } fn visit_f64(self, v: f64) -> Result where E: serde::de::Error, { Ok(ValuePrimative::Number( serde_json::Number::from_f64(v).ok_or_else(|| { serde::de::Error::invalid_value( serde::de::Unexpected::Float(v), &"a finite number", ) })?, )) } fn visit_u8(self, v: u8) -> Result where E: serde::de::Error, { Ok(ValuePrimative::Number(v.into())) } fn visit_u16(self, v: u16) -> Result where E: serde::de::Error, { Ok(ValuePrimative::Number(v.into())) } fn visit_u32(self, v: u32) -> Result where E: serde::de::Error, { Ok(ValuePrimative::Number(v.into())) } fn visit_u64(self, v: u64) -> Result where E: serde::de::Error, { Ok(ValuePrimative::Number(v.into())) } fn visit_i8(self, v: i8) -> Result where E: serde::de::Error, { Ok(ValuePrimative::Number(v.into())) } fn visit_i16(self, v: i16) -> Result where E: serde::de::Error, { Ok(ValuePrimative::Number(v.into())) } fn visit_i32(self, v: i32) -> Result where E: serde::de::Error, { Ok(ValuePrimative::Number(v.into())) } fn visit_i64(self, v: i64) -> Result where E: serde::de::Error, { Ok(ValuePrimative::Number(v.into())) } } deserializer.deserialize_any(Visitor) } } #[derive(Debug, Clone)] pub struct Version { version: emver::Version, string: String, } impl Version { pub fn as_str(&self) -> &str { self.string.as_str() } } impl std::fmt::Display for Version { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.string) } } impl std::str::FromStr for Version { type Err = ::Err; fn from_str(s: &str) -> Result { Ok(Version { string: s.to_owned(), version: s.parse()?, }) } } impl From for Version { fn from(v: emver::Version) -> Self { Version { string: v.to_string(), version: v, } } } impl From for emver::Version { fn from(v: Version) -> Self { v.version } } impl Default for Version { fn default() -> Self { Self::from(emver::Version::default()) } } impl Deref for Version { type Target = emver::Version; fn deref(&self) -> &Self::Target { &self.version } } impl AsRef for Version { fn as_ref(&self) -> &emver::Version { &self.version } } impl AsRef for Version { fn as_ref(&self) -> &str { self.as_str() } } impl PartialEq for Version { fn eq(&self, other: &Version) -> bool { self.version.eq(&other.version) } } impl Eq for Version {} impl PartialOrd for Version { fn partial_cmp(&self, other: &Self) -> Option { self.version.partial_cmp(&other.version) } } impl Ord for Version { fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.version.cmp(&other.version) } } impl Hash for Version { fn hash(&self, state: &mut H) { self.version.hash(state) } } impl<'de> Deserialize<'de> for Version { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { let string = String::deserialize(deserializer)?; let version = emver::Version::from_str(&string).map_err(serde::de::Error::custom)?; Ok(Self { string, version }) } } impl Serialize for Version { fn serialize(&self, serializer: S) -> Result where S: Serializer, { self.string.serialize(serializer) } } impl HasModel for Version { type Model = Model; } #[async_trait] pub trait AsyncFileExt: Sized { async fn maybe_open + Send + Sync>(path: P) -> std::io::Result>; async fn delete + Send + Sync>(path: P) -> std::io::Result<()>; } #[async_trait] impl AsyncFileExt for File { async fn maybe_open + Send + Sync>(path: P) -> std::io::Result> { match File::open(path).await { Ok(f) => Ok(Some(f)), Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None), Err(e) => Err(e), } } async fn delete + Send + Sync>(path: P) -> std::io::Result<()> { if let Ok(m) = tokio::fs::metadata(path.as_ref()).await { if m.is_dir() { tokio::fs::remove_dir_all(path).await } else { tokio::fs::remove_file(path).await } } else { Ok(()) } } } pub struct FmtWriter(W); impl std::io::Write for FmtWriter { fn write(&mut self, buf: &[u8]) -> std::io::Result { self.0 .write_str( std::str::from_utf8(buf) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?, ) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; Ok(buf.len()) } fn flush(&mut self) -> std::io::Result<()> { Ok(()) } } #[derive(Clone, Copy, Debug, Deserialize, Serialize)] #[serde(rename_all = "kebab-case")] pub enum IoFormat { Json, JsonPretty, Yaml, Cbor, Toml, TomlPretty, } impl Default for IoFormat { fn default() -> Self { IoFormat::JsonPretty } } impl std::fmt::Display for IoFormat { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { use IoFormat::*; match self { Json => write!(f, "JSON"), JsonPretty => write!(f, "JSON (pretty)"), Yaml => write!(f, "YAML"), Cbor => write!(f, "CBOR"), Toml => write!(f, "TOML"), TomlPretty => write!(f, "TOML (pretty)"), } } } impl std::str::FromStr for IoFormat { type Err = Error; fn from_str(s: &str) -> Result { serde_json::from_value(Value::String(s.to_owned())) .with_kind(crate::ErrorKind::Deserialization) } } impl IoFormat { pub fn to_writer( &self, mut writer: W, value: &T, ) -> Result<(), Error> { match self { IoFormat::Json => { serde_json::to_writer(writer, value).with_kind(crate::ErrorKind::Serialization) } IoFormat::JsonPretty => serde_json::to_writer_pretty(writer, value) .with_kind(crate::ErrorKind::Serialization), IoFormat::Yaml => { serde_yaml::to_writer(writer, value).with_kind(crate::ErrorKind::Serialization) } IoFormat::Cbor => serde_cbor::ser::into_writer(value, writer) .with_kind(crate::ErrorKind::Serialization), IoFormat::Toml => writer .write_all( &serde_toml::to_vec( &serde_toml::Value::try_from(value) .with_kind(crate::ErrorKind::Serialization)?, ) .with_kind(crate::ErrorKind::Serialization)?, ) .with_kind(crate::ErrorKind::Serialization), IoFormat::TomlPretty => writer .write_all( serde_toml::to_string_pretty( &serde_toml::Value::try_from(value) .with_kind(crate::ErrorKind::Serialization)?, ) .with_kind(crate::ErrorKind::Serialization)? .as_bytes(), ) .with_kind(crate::ErrorKind::Serialization), } } pub fn to_vec(&self, value: &T) -> Result, Error> { match self { IoFormat::Json => serde_json::to_vec(value).with_kind(crate::ErrorKind::Serialization), IoFormat::JsonPretty => { serde_json::to_vec_pretty(value).with_kind(crate::ErrorKind::Serialization) } IoFormat::Yaml => serde_yaml::to_vec(value).with_kind(crate::ErrorKind::Serialization), IoFormat::Cbor => { let mut res = Vec::new(); serde_cbor::ser::into_writer(value, &mut res) .with_kind(crate::ErrorKind::Serialization)?; Ok(res) } IoFormat::Toml => serde_toml::to_vec( &serde_toml::Value::try_from(value).with_kind(crate::ErrorKind::Serialization)?, ) .with_kind(crate::ErrorKind::Serialization), IoFormat::TomlPretty => serde_toml::to_string_pretty( &serde_toml::Value::try_from(value).with_kind(crate::ErrorKind::Serialization)?, ) .map(|s| s.into_bytes()) .with_kind(crate::ErrorKind::Serialization), } } /// BLOCKING pub fn from_reader Deserialize<'de>>( &self, mut reader: R, ) -> Result { match self { IoFormat::Json | IoFormat::JsonPretty => { serde_json::from_reader(reader).with_kind(crate::ErrorKind::Deserialization) } IoFormat::Yaml => { serde_yaml::from_reader(reader).with_kind(crate::ErrorKind::Deserialization) } IoFormat::Cbor => { serde_cbor::de::from_reader(reader).with_kind(crate::ErrorKind::Deserialization) } IoFormat::Toml | IoFormat::TomlPretty => { let mut s = String::new(); reader .read_to_string(&mut s) .with_kind(crate::ErrorKind::Deserialization)?; serde_toml::from_str(&s).with_kind(crate::ErrorKind::Deserialization) } } } pub fn from_slice Deserialize<'de>>(&self, slice: &[u8]) -> Result { match self { IoFormat::Json | IoFormat::JsonPretty => { serde_json::from_slice(slice).with_kind(crate::ErrorKind::Deserialization) } IoFormat::Yaml => { serde_yaml::from_slice(slice).with_kind(crate::ErrorKind::Deserialization) } IoFormat::Cbor => { serde_cbor::de::from_reader(slice).with_kind(crate::ErrorKind::Deserialization) } IoFormat::Toml | IoFormat::TomlPretty => { serde_toml::from_slice(slice).with_kind(crate::ErrorKind::Deserialization) } } } } pub fn display_serializable(t: T, matches: &ArgMatches<'_>) { let format = match matches.value_of("format").map(|f| f.parse()) { Some(Ok(f)) => f, Some(Err(_)) => { eprintln!("unrecognized formatter"); exit(1) } None => IoFormat::default(), }; format .to_writer(std::io::stdout(), &t) .expect("Error serializing result to stdout") } pub fn display_none(_: T, _: &ArgMatches) { () } pub fn parse_stdin_deserializable Deserialize<'de>>( stdin: &mut std::io::Stdin, matches: &ArgMatches<'_>, ) -> Result { let format = match matches.value_of("format").map(|f| f.parse()) { Some(Ok(f)) => f, Some(Err(_)) => { eprintln!("unrecognized formatter"); exit(1) } None => IoFormat::default(), }; format.from_reader(stdin) } pub fn parse_duration(arg: &str, _: &ArgMatches<'_>) -> Result { let units_idx = arg.find(|c: char| c.is_alphabetic()).ok_or_else(|| { Error::new( eyre!("Must specify units for duration"), crate::ErrorKind::Deserialization, ) })?; let (num, units) = arg.split_at(units_idx); match units { "d" if num.contains(".") => Ok(Duration::from_secs_f64(num.parse::()? * 86400_f64)), "d" => Ok(Duration::from_secs(num.parse::()? * 86400)), "h" if num.contains(".") => Ok(Duration::from_secs_f64(num.parse::()? * 3600_f64)), "h" => Ok(Duration::from_secs(num.parse::()? * 3600)), "m" if num.contains(".") => Ok(Duration::from_secs_f64(num.parse::()? * 60_f64)), "m" => Ok(Duration::from_secs(num.parse::()? * 60)), "s" if num.contains(".") => Ok(Duration::from_secs_f64(num.parse()?)), "s" => Ok(Duration::from_secs(num.parse()?)), "ms" => Ok(Duration::from_millis(num.parse()?)), "us" => Ok(Duration::from_micros(num.parse()?)), "ns" => Ok(Duration::from_nanos(num.parse()?)), _ => Err(Error::new( eyre!("Invalid units for duration"), crate::ErrorKind::Deserialization, )), } } pub struct Container(RwLock>); impl Container { pub fn new(value: Option) -> Self { Container(RwLock::new(value)) } pub async fn set(&self, value: T) -> Option { std::mem::replace(&mut *self.0.write().await, Some(value)) } pub async fn take(&self) -> Option { std::mem::replace(&mut *self.0.write().await, None) } pub async fn is_empty(&self) -> bool { self.0.read().await.is_none() } pub async fn drop(&self) { *self.0.write().await = None; } } pub struct HashWriter { hasher: H, writer: W, } impl HashWriter { pub fn new(hasher: H, writer: W) -> Self { HashWriter { hasher, writer } } pub fn finish(self) -> (H, W) { (self.hasher, self.writer) } pub fn inner(&self) -> &W { &self.writer } pub fn inner_mut(&mut self) -> &mut W { &mut self.writer } } impl std::io::Write for HashWriter { fn write(&mut self, buf: &[u8]) -> std::io::Result { let written = self.writer.write(buf)?; self.hasher.update(&buf[..written]); Ok(written) } fn flush(&mut self) -> std::io::Result<()> { self.writer.flush() } } pub fn deserialize_number_permissive< 'de, D: serde::de::Deserializer<'de>, T: FromStr + num::cast::FromPrimitive, E: std::fmt::Display, >( deserializer: D, ) -> std::result::Result { use num::cast::FromPrimitive; struct Visitor + num::cast::FromPrimitive, E>(std::marker::PhantomData); impl<'de, T: FromStr + num::cast::FromPrimitive, Err: std::fmt::Display> serde::de::Visitor<'de> for Visitor { type Value = T; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(formatter, "a parsable string") } fn visit_str(self, v: &str) -> Result where E: serde::de::Error, { v.parse().map_err(|e| serde::de::Error::custom(e)) } fn visit_f64(self, v: f64) -> Result where E: serde::de::Error, { T::from_f64(v).ok_or_else(|| { serde::de::Error::custom(format!( "{} cannot be represented by the requested type", v )) }) } fn visit_u64(self, v: u64) -> Result where E: serde::de::Error, { T::from_u64(v).ok_or_else(|| { serde::de::Error::custom(format!( "{} cannot be represented by the requested type", v )) }) } fn visit_i64(self, v: i64) -> Result where E: serde::de::Error, { T::from_i64(v).ok_or_else(|| { serde::de::Error::custom(format!( "{} cannot be represented by the requested type", v )) }) } } deserializer.deserialize_str(Visitor(std::marker::PhantomData)) } #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Port(pub u16); impl<'de> Deserialize<'de> for Port { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { //TODO: if number, be permissive deserialize_number_permissive(deserializer).map(Port) } } impl Serialize for Port { fn serialize(&self, serializer: S) -> Result where S: Serializer, { serialize_display(&self.0, serializer) } } pub trait IntoDoubleEndedIterator: IntoIterator { type IntoIter: Iterator + DoubleEndedIterator; fn into_iter(self) -> >::IntoIter; } impl IntoDoubleEndedIterator for T where T: IntoIterator, ::IntoIter: DoubleEndedIterator, { type IntoIter = ::IntoIter; fn into_iter(self) -> >::IntoIter { IntoIterator::into_iter(self) } } #[derive(Debug, Clone)] pub struct Reversible> where for<'a> &'a Container: IntoDoubleEndedIterator<&'a T>, { reversed: bool, data: Container, phantom: PhantomData, } impl Reversible where for<'a> &'a Container: IntoDoubleEndedIterator<&'a T>, { pub fn new(data: Container) -> Self { Reversible { reversed: false, data, phantom: PhantomData, } } pub fn reverse(&mut self) { self.reversed = !self.reversed } pub fn iter( &self, ) -> itertools::Either< <&Container as IntoDoubleEndedIterator<&T>>::IntoIter, std::iter::Rev<<&Container as IntoDoubleEndedIterator<&T>>::IntoIter>, > { let iter = IntoDoubleEndedIterator::into_iter(&self.data); if self.reversed { itertools::Either::Right(iter.rev()) } else { itertools::Either::Left(iter) } } } impl Serialize for Reversible where for<'a> &'a Container: IntoDoubleEndedIterator<&'a T>, T: Serialize, { fn serialize(&self, serializer: S) -> Result where S: Serializer, { use serde::ser::SerializeSeq; let iter = IntoDoubleEndedIterator::into_iter(&self.data); let mut seq_ser = serializer.serialize_seq(match iter.size_hint() { (lower, Some(upper)) if lower == upper => Some(upper), _ => None, })?; if self.reversed { for elem in iter.rev() { seq_ser.serialize_element(elem)?; } } else { for elem in iter { seq_ser.serialize_element(elem)?; } } seq_ser.end() } } impl<'de, T, Container> Deserialize<'de> for Reversible where for<'a> &'a Container: IntoDoubleEndedIterator<&'a T>, Container: Deserialize<'de>, { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { Ok(Reversible::new(Deserialize::deserialize(deserializer)?)) } fn deserialize_in_place(deserializer: D, place: &mut Self) -> Result<(), D::Error> where D: Deserializer<'de>, { Deserialize::deserialize_in_place(deserializer, &mut place.data) } } #[pin_project::pin_project(PinnedDrop)] pub struct NonDetachingJoinHandle(#[pin] JoinHandle); impl From> for NonDetachingJoinHandle { fn from(t: JoinHandle) -> Self { NonDetachingJoinHandle(t) } } #[pin_project::pinned_drop] impl PinnedDrop for NonDetachingJoinHandle { fn drop(self: std::pin::Pin<&mut Self>) { let this = self.project(); this.0.into_ref().get_ref().abort() } } impl Future for NonDetachingJoinHandle { type Output = Result; fn poll( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll { let this = self.project(); this.0.poll(cx) } } pub struct GeneralGuard T, T = ()>(Option); impl T, T> GeneralGuard { pub fn new(f: F) -> Self { GeneralGuard(Some(f)) } pub fn drop(mut self) -> T { self.0.take().unwrap()() } } impl T, T> Drop for GeneralGuard { fn drop(&mut self) { if let Some(destroy) = self.0.take() { destroy(); } } }