mirror of
https://github.com/Start9Labs/start-os.git
synced 2026-03-26 10:21:52 +00:00
1034 lines
32 KiB
Rust
1034 lines
32 KiB
Rust
use std::future::Future;
|
|
use std::hash::{Hash, Hasher};
|
|
use std::marker::PhantomData;
|
|
use std::ops::Deref;
|
|
use std::path::Path;
|
|
use std::process::{exit, Stdio};
|
|
use std::str::FromStr;
|
|
use std::time::Duration;
|
|
|
|
use anyhow::anyhow;
|
|
use async_trait::async_trait;
|
|
use clap::ArgMatches;
|
|
use digest::Digest;
|
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
|
use serde_json::Value;
|
|
use tokio::fs::File;
|
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
|
use tokio::sync::RwLock;
|
|
use tokio::task::{JoinError, JoinHandle};
|
|
|
|
use crate::shutdown::Shutdown;
|
|
use crate::{Error, ResultExt as _};
|
|
|
|
/// An uninhabited type: it has no variants, so no value of it can ever exist.
///
/// Any code path that holds a `Never` is statically unreachable.
#[derive(Clone, Copy, Debug)]
pub enum Never {}
impl Never {
    /// Turns an (impossible) `Never` into a value of any type `T`.
    ///
    /// Lets an unreachable branch satisfy whatever type the context expects.
    pub fn absurd<T>(self) -> T {
        match self {}
    }
}
impl std::fmt::Display for Never {
    fn fmt(&self, _f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // There is no value to format; this body can never execute.
        match *self {}
    }
}
impl std::error::Error for Never {}
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct AsyncCompat<T>(pub T);
|
|
impl<T> futures::io::AsyncRead for AsyncCompat<T>
|
|
where
|
|
T: tokio::io::AsyncRead,
|
|
{
|
|
fn poll_read(
|
|
self: std::pin::Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
buf: &mut [u8],
|
|
) -> std::task::Poll<std::io::Result<usize>> {
|
|
let mut read_buf = ReadBuf::new(buf);
|
|
tokio::io::AsyncRead::poll_read(
|
|
unsafe { self.map_unchecked_mut(|a| &mut a.0) },
|
|
cx,
|
|
&mut read_buf,
|
|
)
|
|
.map(|res| res.map(|_| read_buf.filled().len()))
|
|
}
|
|
}
|
|
impl<T> tokio::io::AsyncRead for AsyncCompat<T>
|
|
where
|
|
T: futures::io::AsyncRead,
|
|
{
|
|
fn poll_read(
|
|
self: std::pin::Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
buf: &mut ReadBuf,
|
|
) -> std::task::Poll<std::io::Result<()>> {
|
|
futures::io::AsyncRead::poll_read(
|
|
unsafe { self.map_unchecked_mut(|a| &mut a.0) },
|
|
cx,
|
|
buf.initialize_unfilled(),
|
|
)
|
|
.map(|res| res.map(|len| buf.set_filled(len)))
|
|
}
|
|
}
|
|
impl<T> futures::io::AsyncWrite for AsyncCompat<T>
|
|
where
|
|
T: tokio::io::AsyncWrite,
|
|
{
|
|
fn poll_write(
|
|
self: std::pin::Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
buf: &[u8],
|
|
) -> std::task::Poll<std::io::Result<usize>> {
|
|
tokio::io::AsyncWrite::poll_write(unsafe { self.map_unchecked_mut(|a| &mut a.0) }, cx, buf)
|
|
}
|
|
fn poll_flush(
|
|
self: std::pin::Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> std::task::Poll<std::io::Result<()>> {
|
|
tokio::io::AsyncWrite::poll_flush(unsafe { self.map_unchecked_mut(|a| &mut a.0) }, cx)
|
|
}
|
|
fn poll_close(
|
|
self: std::pin::Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> std::task::Poll<std::io::Result<()>> {
|
|
tokio::io::AsyncWrite::poll_shutdown(unsafe { self.map_unchecked_mut(|a| &mut a.0) }, cx)
|
|
}
|
|
}
|
|
impl<T> tokio::io::AsyncWrite for AsyncCompat<T>
|
|
where
|
|
T: futures::io::AsyncWrite,
|
|
{
|
|
fn poll_write(
|
|
self: std::pin::Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
buf: &[u8],
|
|
) -> std::task::Poll<std::io::Result<usize>> {
|
|
futures::io::AsyncWrite::poll_write(
|
|
unsafe { self.map_unchecked_mut(|a| &mut a.0) },
|
|
cx,
|
|
buf,
|
|
)
|
|
}
|
|
fn poll_flush(
|
|
self: std::pin::Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> std::task::Poll<std::io::Result<()>> {
|
|
futures::io::AsyncWrite::poll_flush(unsafe { self.map_unchecked_mut(|a| &mut a.0) }, cx)
|
|
}
|
|
fn poll_shutdown(
|
|
self: std::pin::Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> std::task::Poll<std::io::Result<()>> {
|
|
futures::io::AsyncWrite::poll_close(unsafe { self.map_unchecked_mut(|a| &mut a.0) }, cx)
|
|
}
|
|
}
|
|
|
|
pub async fn from_yaml_async_reader<T, R>(mut reader: R) -> Result<T, crate::Error>
|
|
where
|
|
T: for<'de> serde::Deserialize<'de>,
|
|
R: AsyncRead + Unpin,
|
|
{
|
|
let mut buffer = Vec::new();
|
|
reader.read_to_end(&mut buffer).await?;
|
|
serde_yaml::from_slice(&buffer)
|
|
.map_err(anyhow::Error::from)
|
|
.with_kind(crate::ErrorKind::Deserialization)
|
|
}
|
|
|
|
pub async fn to_yaml_async_writer<T, W>(mut writer: W, value: &T) -> Result<(), crate::Error>
|
|
where
|
|
T: serde::Serialize,
|
|
W: AsyncWrite + Unpin,
|
|
{
|
|
let mut buffer = serde_yaml::to_vec(value).with_kind(crate::ErrorKind::Serialization)?;
|
|
buffer.extend_from_slice(b"\n");
|
|
writer.write_all(&buffer).await?;
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn from_toml_async_reader<T, R>(mut reader: R) -> Result<T, crate::Error>
|
|
where
|
|
T: for<'de> serde::Deserialize<'de>,
|
|
R: AsyncRead + Unpin,
|
|
{
|
|
let mut buffer = Vec::new();
|
|
reader.read_to_end(&mut buffer).await?;
|
|
serde_toml::from_slice(&buffer)
|
|
.map_err(anyhow::Error::from)
|
|
.with_kind(crate::ErrorKind::Deserialization)
|
|
}
|
|
|
|
pub async fn to_toml_async_writer<T, W>(mut writer: W, value: &T) -> Result<(), crate::Error>
|
|
where
|
|
T: serde::Serialize,
|
|
W: AsyncWrite + Unpin,
|
|
{
|
|
let mut buffer = serde_toml::to_vec(value).with_kind(crate::ErrorKind::Serialization)?;
|
|
buffer.extend_from_slice(b"\n");
|
|
writer.write_all(&buffer).await?;
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn from_cbor_async_reader<T, R>(mut reader: R) -> Result<T, crate::Error>
|
|
where
|
|
T: for<'de> serde::Deserialize<'de>,
|
|
R: AsyncRead + Unpin,
|
|
{
|
|
let mut buffer = Vec::new();
|
|
reader.read_to_end(&mut buffer).await?;
|
|
serde_cbor::de::from_reader(buffer.as_slice())
|
|
.map_err(anyhow::Error::from)
|
|
.with_kind(crate::ErrorKind::Deserialization)
|
|
}
|
|
|
|
pub async fn from_json_async_reader<T, R>(mut reader: R) -> Result<T, crate::Error>
|
|
where
|
|
T: for<'de> serde::Deserialize<'de>,
|
|
R: AsyncRead + Unpin,
|
|
{
|
|
let mut buffer = Vec::new();
|
|
reader.read_to_end(&mut buffer).await?;
|
|
serde_json::from_slice(&buffer)
|
|
.map_err(anyhow::Error::from)
|
|
.with_kind(crate::ErrorKind::Deserialization)
|
|
}
|
|
|
|
pub async fn to_json_async_writer<T, W>(mut writer: W, value: &T) -> Result<(), crate::Error>
|
|
where
|
|
T: serde::Serialize,
|
|
W: AsyncWrite + Unpin,
|
|
{
|
|
let buffer = serde_json::to_string(value).with_kind(crate::ErrorKind::Serialization)?;
|
|
writer.write_all(&buffer.as_bytes()).await?;
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn to_json_pretty_async_writer<T, W>(mut writer: W, value: &T) -> Result<(), crate::Error>
|
|
where
|
|
T: serde::Serialize,
|
|
W: AsyncWrite + Unpin,
|
|
{
|
|
let mut buffer =
|
|
serde_json::to_string_pretty(value).with_kind(crate::ErrorKind::Serialization)?;
|
|
buffer.push_str("\n");
|
|
writer.write_all(&buffer.as_bytes()).await?;
|
|
Ok(())
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
pub trait Invoke {
|
|
async fn invoke(&mut self, error_kind: crate::ErrorKind) -> Result<Vec<u8>, Error>;
|
|
}
|
|
#[async_trait::async_trait]
|
|
impl Invoke for tokio::process::Command {
|
|
async fn invoke(&mut self, error_kind: crate::ErrorKind) -> Result<Vec<u8>, Error> {
|
|
self.stdout(Stdio::piped());
|
|
self.stderr(Stdio::piped());
|
|
let res = self.output().await?;
|
|
crate::ensure_code!(
|
|
res.status.success(),
|
|
error_kind,
|
|
"{}",
|
|
std::str::from_utf8(&res.stderr).unwrap_or("Unknown Error")
|
|
);
|
|
Ok(res.stdout)
|
|
}
|
|
}
|
|
|
|
/// Pipe-style helper: `x.apply(f)` is exactly `f(x)`, with `x` moved into `f`.
pub trait Apply: Sized {
    fn apply<O, F>(self, f: F) -> O
    where
        F: FnOnce(Self) -> O,
    {
        f(self)
    }
}

/// Borrowing counterparts of [`Apply`]: pass `&self` / `&mut self` into a closure.
pub trait ApplyRef {
    /// `x.apply_ref(f)` is `f(&x)`.
    fn apply_ref<O, F>(&self, f: F) -> O
    where
        F: FnOnce(&Self) -> O,
    {
        f(self)
    }

    /// `x.apply_mut(f)` is `f(&mut x)`.
    fn apply_mut<O, F>(&mut self, f: F) -> O
    where
        F: FnOnce(&mut Self) -> O,
    {
        f(self)
    }
}

impl<T> Apply for T {}
impl<T> ApplyRef for T {}
|
|
|
|
pub fn deserialize_from_str<
|
|
'de,
|
|
D: serde::de::Deserializer<'de>,
|
|
T: FromStr<Err = E>,
|
|
E: std::fmt::Display,
|
|
>(
|
|
deserializer: D,
|
|
) -> std::result::Result<T, D::Error> {
|
|
struct Visitor<T: FromStr<Err = E>, E>(std::marker::PhantomData<T>);
|
|
impl<'de, T: FromStr<Err = Err>, Err: std::fmt::Display> serde::de::Visitor<'de>
|
|
for Visitor<T, Err>
|
|
{
|
|
type Value = T;
|
|
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(formatter, "a parsable string")
|
|
}
|
|
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
v.parse().map_err(|e| serde::de::Error::custom(e))
|
|
}
|
|
}
|
|
deserializer.deserialize_str(Visitor(std::marker::PhantomData))
|
|
}
|
|
|
|
pub fn deserialize_from_str_opt<
|
|
'de,
|
|
D: serde::de::Deserializer<'de>,
|
|
T: FromStr<Err = E>,
|
|
E: std::fmt::Display,
|
|
>(
|
|
deserializer: D,
|
|
) -> std::result::Result<Option<T>, D::Error> {
|
|
struct Visitor<T: FromStr<Err = E>, E>(std::marker::PhantomData<T>);
|
|
impl<'de, T: FromStr<Err = Err>, Err: std::fmt::Display> serde::de::Visitor<'de>
|
|
for Visitor<T, Err>
|
|
{
|
|
type Value = Option<T>;
|
|
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(formatter, "a parsable string")
|
|
}
|
|
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
v.parse().map(Some).map_err(|e| serde::de::Error::custom(e))
|
|
}
|
|
fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
|
|
where
|
|
D: serde::de::Deserializer<'de>,
|
|
{
|
|
deserializer.deserialize_str(Visitor(std::marker::PhantomData))
|
|
}
|
|
fn visit_none<E>(self) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(None)
|
|
}
|
|
fn visit_unit<E>(self) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(None)
|
|
}
|
|
}
|
|
deserializer.deserialize_any(Visitor(std::marker::PhantomData))
|
|
}
|
|
|
|
pub fn serialize_display<T: std::fmt::Display, S: Serializer>(
|
|
t: &T,
|
|
serializer: S,
|
|
) -> Result<S::Ok, S::Error> {
|
|
String::serialize(&t.to_string(), serializer)
|
|
}
|
|
|
|
pub fn serialize_display_opt<T: std::fmt::Display, S: Serializer>(
|
|
t: &Option<T>,
|
|
serializer: S,
|
|
) -> Result<S::Ok, S::Error> {
|
|
Option::<String>::serialize(&t.as_ref().map(|t| t.to_string()), serializer)
|
|
}
|
|
|
|
pub async fn daemon<F: FnMut() -> Fut, Fut: Future<Output = ()> + Send + 'static>(
|
|
mut f: F,
|
|
cooldown: std::time::Duration,
|
|
mut shutdown: tokio::sync::broadcast::Receiver<Option<Shutdown>>,
|
|
) -> Result<(), anyhow::Error> {
|
|
while matches!(
|
|
shutdown.try_recv(),
|
|
Err(tokio::sync::broadcast::error::TryRecvError::Empty)
|
|
) {
|
|
match tokio::spawn(f()).await {
|
|
Err(e) if e.is_panic() => return Err(anyhow!("daemon panicked!")),
|
|
_ => (),
|
|
}
|
|
tokio::time::sleep(cooldown).await
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Marker trait implemented by [`SSome`] (holds a `T`) and [`SNone`] (holds nothing).
pub trait SOption<T> {}

/// The value-carrying implementor of [`SOption`].
pub struct SSome<T>(T);
impl<T> SSome<T> {
    /// Consumes the wrapper and returns the contained value.
    pub fn into(self) -> T {
        let SSome(inner) = self;
        inner
    }
}
impl<T> From<T> for SSome<T> {
    fn from(value: T) -> Self {
        Self(value)
    }
}
impl<T> SOption<T> for SSome<T> {}

/// The empty implementor of [`SOption`]; stores no data.
pub struct SNone<T>(PhantomData<T>);
impl<T> SNone<T> {
    pub fn new() -> Self {
        Self(PhantomData)
    }
}
impl<T> SOption<T> for SNone<T> {}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
#[serde(untagged)]
|
|
pub enum ValuePrimative {
|
|
Null,
|
|
Boolean(bool),
|
|
String(String),
|
|
Number(serde_json::Number),
|
|
}
|
|
impl<'de> serde::de::Deserialize<'de> for ValuePrimative {
|
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
where
|
|
D: serde::de::Deserializer<'de>,
|
|
{
|
|
struct Visitor;
|
|
impl<'de> serde::de::Visitor<'de> for Visitor {
|
|
type Value = ValuePrimative;
|
|
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
write!(formatter, "a JSON primative value")
|
|
}
|
|
fn visit_unit<E>(self) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(ValuePrimative::Null)
|
|
}
|
|
fn visit_none<E>(self) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(ValuePrimative::Null)
|
|
}
|
|
fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(ValuePrimative::Boolean(v))
|
|
}
|
|
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(ValuePrimative::String(v.to_owned()))
|
|
}
|
|
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(ValuePrimative::String(v))
|
|
}
|
|
fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(ValuePrimative::Number(
|
|
serde_json::Number::from_f64(v as f64).ok_or_else(|| {
|
|
serde::de::Error::invalid_value(
|
|
serde::de::Unexpected::Float(v as f64),
|
|
&"a finite number",
|
|
)
|
|
})?,
|
|
))
|
|
}
|
|
fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(ValuePrimative::Number(
|
|
serde_json::Number::from_f64(v).ok_or_else(|| {
|
|
serde::de::Error::invalid_value(
|
|
serde::de::Unexpected::Float(v),
|
|
&"a finite number",
|
|
)
|
|
})?,
|
|
))
|
|
}
|
|
fn visit_u8<E>(self, v: u8) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(ValuePrimative::Number(v.into()))
|
|
}
|
|
fn visit_u16<E>(self, v: u16) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(ValuePrimative::Number(v.into()))
|
|
}
|
|
fn visit_u32<E>(self, v: u32) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(ValuePrimative::Number(v.into()))
|
|
}
|
|
fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(ValuePrimative::Number(v.into()))
|
|
}
|
|
fn visit_i8<E>(self, v: i8) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(ValuePrimative::Number(v.into()))
|
|
}
|
|
fn visit_i16<E>(self, v: i16) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(ValuePrimative::Number(v.into()))
|
|
}
|
|
fn visit_i32<E>(self, v: i32) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(ValuePrimative::Number(v.into()))
|
|
}
|
|
fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(ValuePrimative::Number(v.into()))
|
|
}
|
|
}
|
|
deserializer.deserialize_any(Visitor)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct Version {
|
|
version: emver::Version,
|
|
string: String,
|
|
}
|
|
impl Version {
|
|
pub fn as_str(&self) -> &str {
|
|
self.string.as_str()
|
|
}
|
|
}
|
|
impl std::fmt::Display for Version {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(f, "{}", self.string)
|
|
}
|
|
}
|
|
impl std::str::FromStr for Version {
|
|
type Err = <emver::Version as FromStr>::Err;
|
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
Ok(Version {
|
|
string: s.to_owned(),
|
|
version: s.parse()?,
|
|
})
|
|
}
|
|
}
|
|
impl From<emver::Version> for Version {
|
|
fn from(v: emver::Version) -> Self {
|
|
Version {
|
|
string: v.to_string(),
|
|
version: v,
|
|
}
|
|
}
|
|
}
|
|
impl From<Version> for emver::Version {
|
|
fn from(v: Version) -> Self {
|
|
v.version
|
|
}
|
|
}
|
|
impl Deref for Version {
|
|
type Target = emver::Version;
|
|
fn deref(&self) -> &Self::Target {
|
|
&self.version
|
|
}
|
|
}
|
|
impl AsRef<emver::Version> for Version {
|
|
fn as_ref(&self) -> &emver::Version {
|
|
&self.version
|
|
}
|
|
}
|
|
impl AsRef<str> for Version {
|
|
fn as_ref(&self) -> &str {
|
|
self.as_str()
|
|
}
|
|
}
|
|
impl PartialEq for Version {
|
|
fn eq(&self, other: &Version) -> bool {
|
|
self.version.eq(&other.version)
|
|
}
|
|
}
|
|
impl Eq for Version {}
|
|
impl Hash for Version {
|
|
fn hash<H: Hasher>(&self, state: &mut H) {
|
|
self.version.hash(state)
|
|
}
|
|
}
|
|
impl<'de> Deserialize<'de> for Version {
|
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
where
|
|
D: Deserializer<'de>,
|
|
{
|
|
let string = String::deserialize(deserializer)?;
|
|
let version = emver::Version::from_str(&string).map_err(serde::de::Error::custom)?;
|
|
Ok(Self { string, version })
|
|
}
|
|
}
|
|
impl Serialize for Version {
|
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
where
|
|
S: Serializer,
|
|
{
|
|
self.string.serialize(serializer)
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
pub trait AsyncFileExt: Sized {
|
|
async fn maybe_open<P: AsRef<Path> + Send + Sync>(path: P) -> std::io::Result<Option<Self>>;
|
|
async fn delete<P: AsRef<Path> + Send + Sync>(path: P) -> std::io::Result<()>;
|
|
}
|
|
#[async_trait]
|
|
impl AsyncFileExt for File {
|
|
async fn maybe_open<P: AsRef<Path> + Send + Sync>(path: P) -> std::io::Result<Option<Self>> {
|
|
match File::open(path).await {
|
|
Ok(f) => Ok(Some(f)),
|
|
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
|
|
Err(e) => Err(e),
|
|
}
|
|
}
|
|
async fn delete<P: AsRef<Path> + Send + Sync>(path: P) -> std::io::Result<()> {
|
|
if let Ok(m) = tokio::fs::metadata(path.as_ref()).await {
|
|
if m.is_dir() {
|
|
tokio::fs::remove_dir_all(path).await
|
|
} else {
|
|
tokio::fs::remove_file(path).await
|
|
}
|
|
} else {
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Adapts a [`std::fmt::Write`] sink (such as a `String`) into a [`std::io::Write`].
///
/// Each `write` call must receive valid UTF-8 as a whole. A buffer that is not
/// valid UTF-8, including one that splits a multi-byte character across two
/// calls, is rejected with `ErrorKind::InvalidData`.
pub struct FmtWriter<W: std::fmt::Write>(W);
impl<W: std::fmt::Write> std::io::Write for FmtWriter<W> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let text = match std::str::from_utf8(buf) {
            Ok(text) => text,
            Err(e) => return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e)),
        };
        if let Err(e) = self.0.write_str(text) {
            return Err(std::io::Error::new(std::io::ErrorKind::Other, e));
        }
        Ok(buf.len())
    }
    /// Nothing is buffered here, so flushing is a no-op.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
|
|
|
|
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
|
|
#[serde(rename_all = "kebab-case")]
|
|
pub enum IoFormat {
|
|
Json,
|
|
JsonPretty,
|
|
Yaml,
|
|
Cbor,
|
|
Toml,
|
|
TomlPretty,
|
|
}
|
|
impl Default for IoFormat {
|
|
fn default() -> Self {
|
|
IoFormat::JsonPretty
|
|
}
|
|
}
|
|
impl std::fmt::Display for IoFormat {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
use IoFormat::*;
|
|
match self {
|
|
Json => write!(f, "JSON"),
|
|
JsonPretty => write!(f, "JSON (pretty)"),
|
|
Yaml => write!(f, "YAML"),
|
|
Cbor => write!(f, "CBOR"),
|
|
Toml => write!(f, "TOML"),
|
|
TomlPretty => write!(f, "TOML (pretty)"),
|
|
}
|
|
}
|
|
}
|
|
impl std::str::FromStr for IoFormat {
|
|
type Err = Error;
|
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
serde_json::from_value(Value::String(s.to_owned()))
|
|
.with_kind(crate::ErrorKind::Deserialization)
|
|
}
|
|
}
|
|
impl IoFormat {
|
|
pub fn to_writer<W: std::io::Write, T: Serialize>(
|
|
&self,
|
|
mut writer: W,
|
|
value: &T,
|
|
) -> Result<(), Error> {
|
|
match self {
|
|
IoFormat::Json => {
|
|
serde_json::to_writer(writer, value).with_kind(crate::ErrorKind::Serialization)
|
|
}
|
|
IoFormat::JsonPretty => serde_json::to_writer_pretty(writer, value)
|
|
.with_kind(crate::ErrorKind::Serialization),
|
|
IoFormat::Yaml => {
|
|
serde_yaml::to_writer(writer, value).with_kind(crate::ErrorKind::Serialization)
|
|
}
|
|
IoFormat::Cbor => serde_cbor::ser::into_writer(value, writer)
|
|
.with_kind(crate::ErrorKind::Serialization),
|
|
IoFormat::Toml => writer
|
|
.write_all(
|
|
&serde_toml::to_vec(
|
|
&serde_toml::Value::try_from(value)
|
|
.with_kind(crate::ErrorKind::Serialization)?,
|
|
)
|
|
.with_kind(crate::ErrorKind::Serialization)?,
|
|
)
|
|
.with_kind(crate::ErrorKind::Serialization),
|
|
IoFormat::TomlPretty => writer
|
|
.write_all(
|
|
serde_toml::to_string_pretty(
|
|
&serde_toml::Value::try_from(value)
|
|
.with_kind(crate::ErrorKind::Serialization)?,
|
|
)
|
|
.with_kind(crate::ErrorKind::Serialization)?
|
|
.as_bytes(),
|
|
)
|
|
.with_kind(crate::ErrorKind::Serialization),
|
|
}
|
|
}
|
|
pub fn to_vec<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, Error> {
|
|
match self {
|
|
IoFormat::Json => serde_json::to_vec(value).with_kind(crate::ErrorKind::Serialization),
|
|
IoFormat::JsonPretty => {
|
|
serde_json::to_vec_pretty(value).with_kind(crate::ErrorKind::Serialization)
|
|
}
|
|
IoFormat::Yaml => serde_yaml::to_vec(value).with_kind(crate::ErrorKind::Serialization),
|
|
IoFormat::Cbor => {
|
|
let mut res = Vec::new();
|
|
serde_cbor::ser::into_writer(value, &mut res)
|
|
.with_kind(crate::ErrorKind::Serialization)?;
|
|
Ok(res)
|
|
}
|
|
IoFormat::Toml => serde_toml::to_vec(
|
|
&serde_toml::Value::try_from(value).with_kind(crate::ErrorKind::Serialization)?,
|
|
)
|
|
.with_kind(crate::ErrorKind::Serialization),
|
|
IoFormat::TomlPretty => serde_toml::to_string_pretty(
|
|
&serde_toml::Value::try_from(value).with_kind(crate::ErrorKind::Serialization)?,
|
|
)
|
|
.map(|s| s.into_bytes())
|
|
.with_kind(crate::ErrorKind::Serialization),
|
|
}
|
|
}
|
|
/// BLOCKING
|
|
pub fn from_reader<R: std::io::Read, T: for<'de> Deserialize<'de>>(
|
|
&self,
|
|
mut reader: R,
|
|
) -> Result<T, Error> {
|
|
match self {
|
|
IoFormat::Json | IoFormat::JsonPretty => {
|
|
serde_json::from_reader(reader).with_kind(crate::ErrorKind::Deserialization)
|
|
}
|
|
IoFormat::Yaml => {
|
|
serde_yaml::from_reader(reader).with_kind(crate::ErrorKind::Deserialization)
|
|
}
|
|
IoFormat::Cbor => {
|
|
serde_cbor::de::from_reader(reader).with_kind(crate::ErrorKind::Deserialization)
|
|
}
|
|
IoFormat::Toml | IoFormat::TomlPretty => {
|
|
let mut s = String::new();
|
|
reader
|
|
.read_to_string(&mut s)
|
|
.with_kind(crate::ErrorKind::Deserialization)?;
|
|
serde_toml::from_str(&s).with_kind(crate::ErrorKind::Deserialization)
|
|
}
|
|
}
|
|
}
|
|
pub fn from_slice<T: for<'de> Deserialize<'de>>(&self, slice: &[u8]) -> Result<T, Error> {
|
|
match self {
|
|
IoFormat::Json | IoFormat::JsonPretty => {
|
|
serde_json::from_slice(slice).with_kind(crate::ErrorKind::Deserialization)
|
|
}
|
|
IoFormat::Yaml => {
|
|
serde_yaml::from_slice(slice).with_kind(crate::ErrorKind::Deserialization)
|
|
}
|
|
IoFormat::Cbor => {
|
|
serde_cbor::de::from_reader(slice).with_kind(crate::ErrorKind::Deserialization)
|
|
}
|
|
IoFormat::Toml | IoFormat::TomlPretty => {
|
|
serde_toml::from_slice(slice).with_kind(crate::ErrorKind::Deserialization)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn display_serializable<T: Serialize>(t: T, matches: &ArgMatches<'_>) {
|
|
let format = match matches.value_of("format").map(|f| f.parse()) {
|
|
Some(Ok(f)) => f,
|
|
Some(Err(_)) => {
|
|
eprintln!("unrecognized formatter");
|
|
exit(1)
|
|
}
|
|
None => IoFormat::default(),
|
|
};
|
|
format
|
|
.to_writer(std::io::stdout(), &t)
|
|
.expect("Error serializing result to stdout")
|
|
}
|
|
|
|
pub fn display_none<T>(_: T, _: &ArgMatches) {
|
|
()
|
|
}
|
|
|
|
pub fn parse_stdin_deserializable<T: for<'de> Deserialize<'de>>(
|
|
stdin: &mut std::io::Stdin,
|
|
matches: &ArgMatches<'_>,
|
|
) -> Result<T, Error> {
|
|
let format = match matches.value_of("format").map(|f| f.parse()) {
|
|
Some(Ok(f)) => f,
|
|
Some(Err(_)) => {
|
|
eprintln!("unrecognized formatter");
|
|
exit(1)
|
|
}
|
|
None => IoFormat::default(),
|
|
};
|
|
format.from_reader(stdin)
|
|
}
|
|
|
|
pub fn parse_duration(arg: &str, _: &ArgMatches<'_>) -> Result<Duration, Error> {
|
|
let units_idx = arg.find(|c: char| c.is_alphabetic()).ok_or_else(|| {
|
|
Error::new(
|
|
anyhow!("Must specify units for duration"),
|
|
crate::ErrorKind::Deserialization,
|
|
)
|
|
})?;
|
|
let (num, units) = arg.split_at(units_idx);
|
|
match units {
|
|
"d" if num.contains(".") => Ok(Duration::from_secs_f64(num.parse::<f64>()? * 86400_f64)),
|
|
"d" => Ok(Duration::from_secs(num.parse::<u64>()? * 86400)),
|
|
"h" if num.contains(".") => Ok(Duration::from_secs_f64(num.parse::<f64>()? * 3600_f64)),
|
|
"h" => Ok(Duration::from_secs(num.parse::<u64>()? * 3600)),
|
|
"m" if num.contains(".") => Ok(Duration::from_secs_f64(num.parse::<f64>()? * 60_f64)),
|
|
"m" => Ok(Duration::from_secs(num.parse::<u64>()? * 60)),
|
|
"s" if num.contains(".") => Ok(Duration::from_secs_f64(num.parse()?)),
|
|
"s" => Ok(Duration::from_secs(num.parse()?)),
|
|
"ms" => Ok(Duration::from_millis(num.parse()?)),
|
|
"us" => Ok(Duration::from_micros(num.parse()?)),
|
|
"ns" => Ok(Duration::from_nanos(num.parse()?)),
|
|
_ => Err(Error::new(
|
|
anyhow!("Invalid units for duration"),
|
|
crate::ErrorKind::Deserialization,
|
|
)),
|
|
}
|
|
}
|
|
|
|
pub struct Container<T>(RwLock<Option<T>>);
|
|
impl<T> Container<T> {
|
|
pub fn new(value: Option<T>) -> Self {
|
|
Container(RwLock::new(value))
|
|
}
|
|
pub async fn set(&self, value: T) -> Option<T> {
|
|
std::mem::replace(&mut *self.0.write().await, Some(value))
|
|
}
|
|
pub async fn take(&self) -> Option<T> {
|
|
std::mem::replace(&mut *self.0.write().await, None)
|
|
}
|
|
pub async fn is_empty(&self) -> bool {
|
|
self.0.read().await.is_none()
|
|
}
|
|
pub async fn drop(&self) {
|
|
*self.0.write().await = None;
|
|
}
|
|
}
|
|
|
|
pub struct HashWriter<H: Digest, W: std::io::Write> {
|
|
hasher: H,
|
|
writer: W,
|
|
}
|
|
impl<H: Digest, W: std::io::Write> HashWriter<H, W> {
|
|
pub fn new(hasher: H, writer: W) -> Self {
|
|
HashWriter { hasher, writer }
|
|
}
|
|
pub fn finish(self) -> (H, W) {
|
|
(self.hasher, self.writer)
|
|
}
|
|
}
|
|
impl<H: Digest, W: std::io::Write> std::io::Write for HashWriter<H, W> {
|
|
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
|
let written = self.writer.write(buf)?;
|
|
self.hasher.update(&buf[..written]);
|
|
Ok(written)
|
|
}
|
|
fn flush(&mut self) -> std::io::Result<()> {
|
|
self.writer.flush()
|
|
}
|
|
}
|
|
impl<H: Digest, W: std::io::Write> std::ops::Deref for HashWriter<H, W> {
|
|
type Target = W;
|
|
fn deref(&self) -> &Self::Target {
|
|
&self.writer
|
|
}
|
|
}
|
|
impl<H: Digest, W: std::io::Write> std::ops::DerefMut for HashWriter<H, W> {
|
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
|
&mut self.writer
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
|
pub struct Port(pub u16);
|
|
impl<'de> Deserialize<'de> for Port {
|
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
where
|
|
D: Deserializer<'de>,
|
|
{
|
|
//TODO: if number, be permissive
|
|
deserialize_from_str(deserializer).map(Port)
|
|
}
|
|
}
|
|
impl Serialize for Port {
|
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
where
|
|
S: Serializer,
|
|
{
|
|
serialize_display(&self.0, serializer)
|
|
}
|
|
}
|
|
|
|
/// An [`IntoIterator`] whose iterator can also be walked from the back.
///
/// Blanket-implemented for every `IntoIterator` whose `IntoIter` is a
/// [`DoubleEndedIterator`] (e.g. `&Vec<T>`, `&[T]`).
pub trait IntoDoubleEndedIterator<T>: IntoIterator<Item = T> {
    type IntoIter: DoubleEndedIterator + Iterator<Item = T>;
    fn into_iter(self) -> <Self as IntoDoubleEndedIterator<T>>::IntoIter;
}
impl<I, Item> IntoDoubleEndedIterator<Item> for I
where
    I: IntoIterator<Item = Item>,
    <I as IntoIterator>::IntoIter: DoubleEndedIterator,
{
    type IntoIter = <I as IntoIterator>::IntoIter;
    fn into_iter(self) -> <Self as IntoDoubleEndedIterator<Item>>::IntoIter {
        <I as IntoIterator>::into_iter(self)
    }
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct Reversible<T, Container = Vec<T>>
|
|
where
|
|
for<'a> &'a Container: IntoDoubleEndedIterator<&'a T>,
|
|
{
|
|
reversed: bool,
|
|
data: Container,
|
|
phantom: PhantomData<T>,
|
|
}
|
|
impl<T, Container> Reversible<T, Container>
|
|
where
|
|
for<'a> &'a Container: IntoDoubleEndedIterator<&'a T>,
|
|
{
|
|
pub fn new(data: Container) -> Self {
|
|
Reversible {
|
|
reversed: false,
|
|
data,
|
|
phantom: PhantomData,
|
|
}
|
|
}
|
|
|
|
pub fn reverse(&mut self) {
|
|
self.reversed = !self.reversed
|
|
}
|
|
|
|
pub fn iter(
|
|
&self,
|
|
) -> itertools::Either<
|
|
<&Container as IntoDoubleEndedIterator<&T>>::IntoIter,
|
|
std::iter::Rev<<&Container as IntoDoubleEndedIterator<&T>>::IntoIter>,
|
|
> {
|
|
let iter = IntoDoubleEndedIterator::into_iter(&self.data);
|
|
if self.reversed {
|
|
itertools::Either::Right(iter.rev())
|
|
} else {
|
|
itertools::Either::Left(iter)
|
|
}
|
|
}
|
|
}
|
|
impl<T, Container> Serialize for Reversible<T, Container>
|
|
where
|
|
for<'a> &'a Container: IntoDoubleEndedIterator<&'a T>,
|
|
T: Serialize,
|
|
{
|
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
where
|
|
S: Serializer,
|
|
{
|
|
use serde::ser::SerializeSeq;
|
|
|
|
let iter = IntoDoubleEndedIterator::into_iter(&self.data);
|
|
let mut seq_ser = serializer.serialize_seq(match iter.size_hint() {
|
|
(lower, Some(upper)) if lower == upper => Some(upper),
|
|
_ => None,
|
|
})?;
|
|
if self.reversed {
|
|
for elem in iter.rev() {
|
|
seq_ser.serialize_element(elem)?;
|
|
}
|
|
} else {
|
|
for elem in iter {
|
|
seq_ser.serialize_element(elem)?;
|
|
}
|
|
}
|
|
seq_ser.end()
|
|
}
|
|
}
|
|
impl<'de, T, Container> Deserialize<'de> for Reversible<T, Container>
|
|
where
|
|
for<'a> &'a Container: IntoDoubleEndedIterator<&'a T>,
|
|
Container: Deserialize<'de>,
|
|
{
|
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
where
|
|
D: Deserializer<'de>,
|
|
{
|
|
Ok(Reversible::new(Deserialize::deserialize(deserializer)?))
|
|
}
|
|
fn deserialize_in_place<D>(deserializer: D, place: &mut Self) -> Result<(), D::Error>
|
|
where
|
|
D: Deserializer<'de>,
|
|
{
|
|
Deserialize::deserialize_in_place(deserializer, &mut place.data)
|
|
}
|
|
}
|
|
|
|
#[pin_project::pin_project(PinnedDrop)]
|
|
pub struct NonDetachingJoinHandle<T>(#[pin] JoinHandle<T>);
|
|
impl<T> From<JoinHandle<T>> for NonDetachingJoinHandle<T> {
|
|
fn from(t: JoinHandle<T>) -> Self {
|
|
NonDetachingJoinHandle(t)
|
|
}
|
|
}
|
|
#[pin_project::pinned_drop]
|
|
impl<T> PinnedDrop for NonDetachingJoinHandle<T> {
|
|
fn drop(self: std::pin::Pin<&mut Self>) {
|
|
let this = self.project();
|
|
this.0.into_ref().get_ref().abort()
|
|
}
|
|
}
|
|
impl<T> Future for NonDetachingJoinHandle<T> {
|
|
type Output = Result<T, JoinError>;
|
|
fn poll(
|
|
self: std::pin::Pin<&mut Self>,
|
|
cx: &mut std::task::Context<'_>,
|
|
) -> std::task::Poll<Self::Output> {
|
|
let this = self.project();
|
|
this.0.poll(cx)
|
|
}
|
|
}
|