rename frontend to web and update contributing guide (#2509)

* rename frontend to web and update contributing guide

* rename this time

* fix build

* restructure rust code

* update documentation

* update descriptions

* Update CONTRIBUTING.md

Co-authored-by: J H <2364004+Blu-J@users.noreply.github.com>

---------

Co-authored-by: Aiden McClelland <me@drbonez.dev>
Co-authored-by: Aiden McClelland <3732071+dr-bonez@users.noreply.github.com>
Co-authored-by: J H <2364004+Blu-J@users.noreply.github.com>
This commit is contained in:
Matt Hill
2023-11-13 14:22:23 -07:00
committed by GitHub
parent 871f78b570
commit 86567e7fa5
968 changed files with 812 additions and 6672 deletions

View File

@@ -0,0 +1,58 @@
use std::fs::File;
use std::path::{Path, PathBuf};
use patch_db::Value;
use serde::Deserialize;
use crate::prelude::*;
use crate::util::serde::IoFormat;
use crate::{Config, Error};
pub const DEVICE_CONFIG_PATH: &str = "/media/embassy/config/config.yaml";
pub const CONFIG_PATH: &str = "/etc/embassy/config.yaml";
pub const CONFIG_PATH_LOCAL: &str = ".embassy/config.yaml";
pub fn local_config_path() -> Option<PathBuf> {
if let Ok(home) = std::env::var("HOME") {
Some(Path::new(&home).join(CONFIG_PATH_LOCAL))
} else {
None
}
}
/// BLOCKING
pub fn load_config_from_paths<'a, T: for<'de> Deserialize<'de>>(
paths: impl IntoIterator<Item = impl AsRef<Path>>,
) -> Result<T, Error> {
let mut config = Default::default();
for path in paths {
if path.as_ref().exists() {
let format: IoFormat = path
.as_ref()
.extension()
.and_then(|s| s.to_str())
.map(|f| f.parse())
.transpose()?
.unwrap_or_default();
let new = format.from_reader(File::open(path)?)?;
config = merge_configs(config, new);
}
}
from_value(Value::Object(config))
}
pub fn merge_configs(mut first: Config, second: Config) -> Config {
for (k, v) in second.into_iter() {
let new = match first.remove(&k) {
None => v,
Some(old) => match (old, v) {
(Value::Object(first), Value::Object(second)) => {
Value::Object(merge_configs(first, second))
}
(first, _) => first,
},
};
first.insert(k, new);
}
first
}

View File

@@ -0,0 +1,125 @@
use std::borrow::Cow;
use std::collections::BTreeSet;
use imbl::OrdMap;
use tokio::process::Command;
use crate::prelude::*;
use crate::util::Invoke;
pub const GOVERNOR_PERFORMANCE: Governor = Governor(Cow::Borrowed("performance"));
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Governor(Cow<'static, str>);
impl std::fmt::Display for Governor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl std::ops::Deref for Governor {
type Target = str;
fn deref(&self) -> &Self::Target {
&*self.0
}
}
impl std::borrow::Borrow<str> for Governor {
fn borrow(&self) -> &str {
&**self
}
}
pub async fn get_available_governors() -> Result<BTreeSet<Governor>, Error> {
let raw = String::from_utf8(
Command::new("cpupower")
.arg("frequency-info")
.arg("-g")
.invoke(ErrorKind::CpuSettings)
.await?,
)?;
let mut for_cpu: OrdMap<u32, BTreeSet<Governor>> = OrdMap::new();
let mut current_cpu = None;
for line in raw.lines() {
if line.starts_with("analyzing") {
current_cpu = Some(
sscanf::sscanf!(line, "analyzing CPU {u32}:")
.map_err(|e| eyre!("{e}"))
.with_kind(ErrorKind::ParseSysInfo)?,
);
} else if let Some(rest) = line
.trim()
.strip_prefix("available cpufreq governors:")
.map(|s| s.trim())
{
if rest != "Not Available" {
for_cpu
.entry(current_cpu.ok_or_else(|| {
Error::new(
eyre!("governors listed before cpu"),
ErrorKind::ParseSysInfo,
)
})?)
.or_default()
.extend(
rest.split_ascii_whitespace()
.map(|g| Governor(Cow::Owned(g.to_owned()))),
);
}
}
}
Ok(for_cpu
.into_iter()
.fold(None, |acc: Option<BTreeSet<Governor>>, (_, x)| {
if let Some(acc) = acc {
Some(acc.intersection(&x).cloned().collect())
} else {
Some(x)
}
})
.unwrap_or_default()) // include only governors available for ALL cpus
}
pub async fn current_governor() -> Result<Option<Governor>, Error> {
let Some(raw) = Command::new("cpupower")
.arg("frequency-info")
.arg("-p")
.invoke(ErrorKind::CpuSettings)
.await
.and_then(|s| Ok(Some(String::from_utf8(s)?)))
.or_else(|e| {
if e.source
.to_string()
.contains("Unable to determine current policy")
{
Ok(None)
} else {
Err(e)
}
})?
else {
return Ok(None);
};
for line in raw.lines() {
if let Some(governor) = line
.trim()
.strip_prefix("The governor \"")
.and_then(|s| s.strip_suffix("\" may decide which speed to use"))
{
return Ok(Some(Governor(Cow::Owned(governor.to_owned()))));
}
}
Err(Error::new(
eyre!("Failed to parse cpupower output:\n{raw}"),
ErrorKind::ParseSysInfo,
))
}
pub async fn set_governor(governor: &Governor) -> Result<(), Error> {
Command::new("cpupower")
.arg("frequency-set")
.arg("-g")
.arg(&*governor.0)
.invoke(ErrorKind::CpuSettings)
.await?;
Ok(())
}

View File

@@ -0,0 +1,9 @@
use ed25519_dalek::{SecretKey, EXPANDED_SECRET_KEY_LENGTH};
#[inline]
pub fn ed25519_expand_key(key: &SecretKey) -> [u8; EXPANDED_SECRET_KEY_LENGTH] {
ed25519_dalek_v1::ExpandedSecretKey::from(
&ed25519_dalek_v1::SecretKey::from_bytes(key).unwrap(),
)
.to_bytes()
}

View File

@@ -0,0 +1,239 @@
use std::net::Ipv4Addr;
use std::time::Duration;
use models::{Error, ErrorKind, PackageId, ResultExt, Version};
use nix::sys::signal::Signal;
use tokio::process::Command;
use crate::util::Invoke;
#[cfg(feature = "docker")]
pub const CONTAINER_TOOL: &str = "docker";
#[cfg(not(feature = "docker"))]
pub const CONTAINER_TOOL: &str = "podman";
#[cfg(feature = "docker")]
pub const CONTAINER_DATADIR: &str = "/var/lib/docker";
#[cfg(not(feature = "docker"))]
pub const CONTAINER_DATADIR: &str = "/var/lib/containers";
pub struct DockerImageSha(String);
// docker images start9/${package}/*:${version} -q --no-trunc
pub async fn images_for(
package: &PackageId,
version: &Version,
) -> Result<Vec<DockerImageSha>, Error> {
Ok(String::from_utf8(
Command::new(CONTAINER_TOOL)
.arg("images")
.arg(format!("start9/{package}/*:{version}"))
.arg("--no-trunc")
.arg("-q")
.invoke(ErrorKind::Docker)
.await?,
)?
.lines()
.map(|l| DockerImageSha(l.trim().to_owned()))
.collect())
}
// docker rmi -f ${sha}
pub async fn remove_image(sha: &DockerImageSha) -> Result<(), Error> {
match Command::new(CONTAINER_TOOL)
.arg("rmi")
.arg("-f")
.arg(&sha.0)
.invoke(ErrorKind::Docker)
.await
.map(|_| ())
{
Err(e)
if e.source
.to_string()
.to_ascii_lowercase()
.contains("no such image") =>
{
Ok(())
}
a => a,
}?;
Ok(())
}
// docker image prune -f
pub async fn prune_images() -> Result<(), Error> {
Command::new(CONTAINER_TOOL)
.arg("image")
.arg("prune")
.arg("-f")
.invoke(ErrorKind::Docker)
.await?;
Ok(())
}
// docker container inspect ${name} --format '{{.NetworkSettings.Networks.start9.IPAddress}}'
pub async fn get_container_ip(name: &str) -> Result<Option<Ipv4Addr>, Error> {
match Command::new(CONTAINER_TOOL)
.arg("container")
.arg("inspect")
.arg(name)
.arg("--format")
.arg("{{.NetworkSettings.Networks.start9.IPAddress}}")
.invoke(ErrorKind::Docker)
.await
{
Err(e)
if e.source
.to_string()
.to_ascii_lowercase()
.contains("no such container") =>
{
Ok(None)
}
Err(e) => Err(e),
Ok(a) => {
let out = std::str::from_utf8(&a)?.trim();
if out.is_empty() {
Ok(None)
} else {
Ok(Some({
out.parse()
.with_ctx(|_| (ErrorKind::ParseNetAddress, out.to_string()))?
}))
}
}
}
}
// docker stop -t ${timeout} -s ${signal} ${name}
pub async fn stop_container(
name: &str,
timeout: Option<Duration>,
signal: Option<Signal>,
) -> Result<(), Error> {
let mut cmd = Command::new(CONTAINER_TOOL);
cmd.arg("stop");
if let Some(dur) = timeout {
cmd.arg("-t").arg(dur.as_secs().to_string());
}
if let Some(sig) = signal {
cmd.arg("-s").arg(sig.to_string());
}
cmd.arg(name);
match cmd.invoke(ErrorKind::Docker).await {
Ok(_) => Ok(()),
Err(mut e)
if e.source
.to_string()
.to_ascii_lowercase()
.contains("no such container") =>
{
e.kind = ErrorKind::NotFound;
Err(e)
}
Err(e) => Err(e),
}
}
// docker kill -s ${signal} ${name}
pub async fn kill_container(name: &str, signal: Option<Signal>) -> Result<(), Error> {
let mut cmd = Command::new(CONTAINER_TOOL);
cmd.arg("kill");
if let Some(sig) = signal {
cmd.arg("-s").arg(sig.to_string());
}
cmd.arg(name);
match cmd.invoke(ErrorKind::Docker).await {
Ok(_) => Ok(()),
Err(mut e)
if e.source
.to_string()
.to_ascii_lowercase()
.contains("no such container") =>
{
e.kind = ErrorKind::NotFound;
Err(e)
}
Err(e) => Err(e),
}
}
// docker pause ${name}
pub async fn pause_container(name: &str) -> Result<(), Error> {
let mut cmd = Command::new(CONTAINER_TOOL);
cmd.arg("pause");
cmd.arg(name);
match cmd.invoke(ErrorKind::Docker).await {
Ok(_) => Ok(()),
Err(mut e)
if e.source
.to_string()
.to_ascii_lowercase()
.contains("no such container") =>
{
e.kind = ErrorKind::NotFound;
Err(e)
}
Err(e) => Err(e),
}
}
// docker unpause ${name}
pub async fn unpause_container(name: &str) -> Result<(), Error> {
let mut cmd = Command::new(CONTAINER_TOOL);
cmd.arg("unpause");
cmd.arg(name);
match cmd.invoke(ErrorKind::Docker).await {
Ok(_) => Ok(()),
Err(mut e)
if e.source
.to_string()
.to_ascii_lowercase()
.contains("no such container") =>
{
e.kind = ErrorKind::NotFound;
Err(e)
}
Err(e) => Err(e),
}
}
// docker rm -f ${name}
pub async fn remove_container(name: &str, force: bool) -> Result<(), Error> {
let mut cmd = Command::new(CONTAINER_TOOL);
cmd.arg("rm");
if force {
cmd.arg("-f");
}
cmd.arg(name);
match cmd.invoke(ErrorKind::Docker).await {
Ok(_) => Ok(()),
Err(e)
if e.source
.to_string()
.to_ascii_lowercase()
.contains("no such container") =>
{
Ok(())
}
Err(e) => Err(e),
}
}
// docker network create -d bridge --subnet ${subnet} --opt com.podman.network.bridge.name=${bridge_name}
pub async fn create_bridge_network(
name: &str,
subnet: &str,
bridge_name: &str,
) -> Result<(), Error> {
let mut cmd = Command::new(CONTAINER_TOOL);
cmd.arg("network").arg("create");
cmd.arg("-d").arg("bridge");
cmd.arg("--subnet").arg(subnet);
cmd.arg("--opt")
.arg(format!("com.docker.network.bridge.name={bridge_name}"));
cmd.arg(name);
cmd.invoke(ErrorKind::Docker).await?;
Ok(())
}

View File

@@ -0,0 +1,380 @@
use std::cmp::min;
use std::convert::TryFrom;
use std::fmt::Display;
use std::future::Future;
use std::io::Error as StdIOError;
use std::pin::Pin;
use std::task::{Context, Poll};
use color_eyre::eyre::eyre;
use futures::Stream;
use http::header::{ACCEPT_RANGES, CONTENT_LENGTH, RANGE};
use hyper::body::Bytes;
use pin_project::pin_project;
use reqwest::{Client, Url};
use tokio::io::{AsyncRead, AsyncSeek};
use crate::{Error, ResultExt};
#[pin_project]
pub struct HttpReader {
http_url: Url,
cursor_pos: usize,
http_client: Client,
total_bytes: usize,
range_unit: Option<RangeUnit>,
read_in_progress: ReadInProgress,
}
type InProgress = Pin<
Box<
dyn Future<
Output = Result<
Pin<
Box<
dyn Stream<Item = Result<Bytes, reqwest::Error>>
+ Send
+ Sync
+ 'static,
>,
>,
Error,
>,
> + Send
+ Sync
+ 'static,
>,
>;
enum ReadInProgress {
None,
InProgress(InProgress),
Complete(Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send + Sync + 'static>>),
}
impl ReadInProgress {
fn take(&mut self) -> Self {
std::mem::replace(self, Self::None)
}
}
// If we want to add support for units other than Accept-Ranges: bytes, we can use this enum
#[derive(Clone, Copy)]
enum RangeUnit {
Bytes,
}
impl Default for RangeUnit {
fn default() -> Self {
RangeUnit::Bytes
}
}
impl Display for RangeUnit {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RangeUnit::Bytes => write!(f, "bytes"),
}
}
}
impl HttpReader {
pub async fn new(http_url: Url) -> Result<Self, Error> {
let http_client = Client::builder()
// .proxy(reqwest::Proxy::all("socks5h://127.0.0.1:9050").unwrap())
.build()
.with_kind(crate::ErrorKind::TLSInit)?;
// Make a head request so that we can get the file size and check for http range support.
let head_request = http_client
.head(http_url.clone())
.send()
.await
.with_kind(crate::ErrorKind::InvalidRequest)?;
let accept_ranges = head_request.headers().get(ACCEPT_RANGES);
let range_unit = match accept_ranges {
Some(range_type) => {
// as per rfc, header will contain data but not always UTF8 characters.
let value = range_type
.to_str()
.map_err(|err| Error::new(err, crate::ErrorKind::Utf8))?;
match value {
"bytes" => Some(RangeUnit::Bytes),
_ => {
return Err(Error::new(
eyre!(
"{} HTTP range downloading not supported with this unit {value}",
http_url
),
crate::ErrorKind::MissingHeader,
));
}
}
}
// None can mean just get entire contents, but we currently error out.
None => {
return Err(Error::new(
eyre!(
"{} HTTP range downloading not supported with this url",
http_url
),
crate::ErrorKind::MissingHeader,
))
}
};
let total_bytes_option = head_request.headers().get(CONTENT_LENGTH);
let total_bytes = match total_bytes_option {
Some(bytes) => bytes
.to_str()
.map_err(|err| Error::new(err, crate::ErrorKind::Utf8))?
.parse::<usize>()?,
None => {
return Err(Error::new(
eyre!("No content length headers for {}", http_url),
crate::ErrorKind::MissingHeader,
))
}
};
Ok(HttpReader {
http_url,
cursor_pos: 0,
http_client,
total_bytes,
range_unit,
read_in_progress: ReadInProgress::None,
})
}
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests
async fn get_range(
range_unit: Option<RangeUnit>,
http_client: Client,
http_url: Url,
start: usize,
len: usize,
total_bytes: usize,
) -> Result<
Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send + Sync + 'static>>,
Error,
> {
let end = min(start + len, total_bytes) - 1;
if start > end {
return Ok(Box::pin(futures::stream::empty()));
}
let data_range = format!("{}={}-{} ", range_unit.unwrap_or_default(), start, end);
let data_resp = http_client
.get(http_url)
.header(RANGE, data_range)
.send()
.await
.with_kind(crate::ErrorKind::Network)?
.error_for_status()
.with_kind(crate::ErrorKind::Network)?;
Ok(Box::pin(data_resp.bytes_stream()))
}
}
impl AsyncRead for HttpReader {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
fn poll_complete(
body: &mut Pin<
Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send + Sync + 'static>,
>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<Option<std::io::Result<usize>>> {
Poll::Ready(match futures::ready!(body.as_mut().poll_next(cx)) {
Some(Ok(bytes)) => {
if buf.remaining() < bytes.len() {
Some(Err(StdIOError::new(
std::io::ErrorKind::InvalidInput,
format!("more bytes returned than expected"),
)))
} else {
buf.put_slice(&*bytes);
Some(Ok(bytes.len()))
}
}
Some(Err(e)) => Some(Err(StdIOError::new(std::io::ErrorKind::Interrupted, e))),
None => None,
})
}
let this = self.project();
loop {
let mut in_progress = match this.read_in_progress.take() {
ReadInProgress::Complete(mut body) => match poll_complete(&mut body, cx, buf) {
Poll::Pending => {
*this.read_in_progress = ReadInProgress::Complete(body);
return Poll::Pending;
}
Poll::Ready(Some(Ok(len))) => {
*this.read_in_progress = ReadInProgress::Complete(body);
*this.cursor_pos += len;
return Poll::Ready(Ok(()));
}
Poll::Ready(res) => {
if let Some(Err(e)) = res {
tracing::error!(
"Error reading bytes from {}: {}, attempting to resume download",
this.http_url,
e
);
tracing::debug!("{:?}", e);
}
if *this.cursor_pos == *this.total_bytes {
return Poll::Ready(Ok(()));
}
continue;
}
},
ReadInProgress::None => Box::pin(HttpReader::get_range(
*this.range_unit,
this.http_client.clone(),
this.http_url.clone(),
*this.cursor_pos,
buf.remaining(),
*this.total_bytes,
)),
ReadInProgress::InProgress(fut) => fut,
};
let res_poll = in_progress.as_mut().poll(cx);
match res_poll {
Poll::Ready(result) => match result {
Ok(body) => {
*this.read_in_progress = ReadInProgress::Complete(body);
}
Err(err) => {
break Poll::Ready(Err(StdIOError::new(
std::io::ErrorKind::Interrupted,
Box::<dyn std::error::Error + Send + Sync>::from(err.source),
)));
}
},
Poll::Pending => {
*this.read_in_progress = ReadInProgress::InProgress(in_progress);
break Poll::Pending;
}
}
}
}
}
impl AsyncSeek for HttpReader {
fn start_seek(self: Pin<&mut Self>, position: std::io::SeekFrom) -> std::io::Result<()> {
let this = self.project();
this.read_in_progress.take(); // invalidate any existing reads
match position {
std::io::SeekFrom::Start(offset) => {
let pos_res = usize::try_from(offset);
match pos_res {
Ok(pos) => {
if pos > *this.total_bytes {
StdIOError::new(
std::io::ErrorKind::InvalidInput,
format!(
"The offset: {} cannot be greater than {} bytes",
pos, *this.total_bytes
),
);
}
*this.cursor_pos = pos;
}
Err(err) => return Err(StdIOError::new(std::io::ErrorKind::InvalidInput, err)),
}
Ok(())
}
std::io::SeekFrom::Current(offset) => {
// We explicitly check if we read before byte 0.
let new_pos = i64::try_from(*this.cursor_pos)
.map_err(|err| StdIOError::new(std::io::ErrorKind::InvalidInput, err))?
+ offset;
if new_pos < 0 {
return Err(StdIOError::new(
std::io::ErrorKind::InvalidInput,
"Can't read before byte 0",
));
}
*this.cursor_pos = new_pos as usize;
Ok(())
}
std::io::SeekFrom::End(offset) => {
// We explicitly check if we read before byte 0.
let new_pos = i64::try_from(*this.total_bytes)
.map_err(|err| StdIOError::new(std::io::ErrorKind::InvalidInput, err))?
+ offset;
if new_pos < 0 {
return Err(StdIOError::new(
std::io::ErrorKind::InvalidInput,
"Can't read before byte 0",
));
}
*this.cursor_pos = new_pos as usize;
Ok(())
}
}
}
fn poll_complete(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<u64>> {
Poll::Ready(Ok(self.cursor_pos as u64))
}
}
#[tokio::test]
async fn main_test() {
let http_url = Url::parse("https://start9.com/latest/_static/css/main.css").unwrap();
println!("Getting this resource: {}", http_url);
let mut test_reader = HttpReader::new(http_url).await.unwrap();
let mut buf = Vec::new();
tokio::io::copy(&mut test_reader, &mut buf).await.unwrap();
assert_eq!(buf.len(), test_reader.total_bytes)
}
#[tokio::test]
#[ignore]
async fn s9pk_test() {
use tokio::io::BufReader;
let http_url = Url::parse("http://qhc6ac47cytstejcepk2ia3ipadzjhlkc5qsktsbl4e7u2krfmfuaqqd.onion/content/files/2022/09/ghost.s9pk").unwrap();
println!("Getting this resource: {}", http_url);
let test_reader =
BufReader::with_capacity(1024 * 1024, HttpReader::new(http_url).await.unwrap());
let mut s9pk = crate::s9pk::reader::S9pkReader::from_reader(test_reader, false)
.await
.unwrap();
let manifest = s9pk.manifest().await.unwrap();
assert_eq!(&manifest.id.to_string(), "ghost");
}

671
core/startos/src/util/io.rs Normal file
View File

@@ -0,0 +1,671 @@
use std::future::Future;
use std::io::Cursor;
use std::os::unix::prelude::MetadataExt;
use std::path::Path;
use std::sync::atomic::AtomicU64;
use std::task::Poll;
use std::time::Duration;
use futures::future::{BoxFuture, Fuse};
use futures::{AsyncSeek, FutureExt, TryStreamExt};
use helpers::NonDetachingJoinHandle;
use nix::unistd::{Gid, Uid};
use tokio::io::{
duplex, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf, WriteHalf,
};
use tokio::net::TcpStream;
use tokio::time::{Instant, Sleep};
use crate::ResultExt;
pub trait AsyncReadSeek: AsyncRead + AsyncSeek {}
impl<T: AsyncRead + AsyncSeek> AsyncReadSeek for T {}
#[derive(Clone, Debug)]
pub struct AsyncCompat<T>(pub T);
impl<T> futures::io::AsyncRead for AsyncCompat<T>
where
T: tokio::io::AsyncRead,
{
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut [u8],
) -> std::task::Poll<std::io::Result<usize>> {
let mut read_buf = ReadBuf::new(buf);
tokio::io::AsyncRead::poll_read(
unsafe { self.map_unchecked_mut(|a| &mut a.0) },
cx,
&mut read_buf,
)
.map(|res| res.map(|_| read_buf.filled().len()))
}
}
impl<T> tokio::io::AsyncRead for AsyncCompat<T>
where
T: futures::io::AsyncRead,
{
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut ReadBuf,
) -> std::task::Poll<std::io::Result<()>> {
futures::io::AsyncRead::poll_read(
unsafe { self.map_unchecked_mut(|a| &mut a.0) },
cx,
buf.initialize_unfilled(),
)
.map(|res| res.map(|len| buf.set_filled(len)))
}
}
impl<T> futures::io::AsyncWrite for AsyncCompat<T>
where
T: tokio::io::AsyncWrite,
{
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
tokio::io::AsyncWrite::poll_write(unsafe { self.map_unchecked_mut(|a| &mut a.0) }, cx, buf)
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
tokio::io::AsyncWrite::poll_flush(unsafe { self.map_unchecked_mut(|a| &mut a.0) }, cx)
}
fn poll_close(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
tokio::io::AsyncWrite::poll_shutdown(unsafe { self.map_unchecked_mut(|a| &mut a.0) }, cx)
}
}
impl<T> tokio::io::AsyncWrite for AsyncCompat<T>
where
T: futures::io::AsyncWrite,
{
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<std::io::Result<usize>> {
futures::io::AsyncWrite::poll_write(
unsafe { self.map_unchecked_mut(|a| &mut a.0) },
cx,
buf,
)
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
futures::io::AsyncWrite::poll_flush(unsafe { self.map_unchecked_mut(|a| &mut a.0) }, cx)
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<std::io::Result<()>> {
futures::io::AsyncWrite::poll_close(unsafe { self.map_unchecked_mut(|a| &mut a.0) }, cx)
}
}
pub async fn from_yaml_async_reader<T, R>(mut reader: R) -> Result<T, crate::Error>
where
T: for<'de> serde::Deserialize<'de>,
R: AsyncRead + Unpin,
{
let mut buffer = Vec::new();
reader.read_to_end(&mut buffer).await?;
serde_yaml::from_slice(&buffer)
.map_err(color_eyre::eyre::Error::from)
.with_kind(crate::ErrorKind::Deserialization)
}
pub async fn to_yaml_async_writer<T, W>(mut writer: W, value: &T) -> Result<(), crate::Error>
where
T: serde::Serialize,
W: AsyncWrite + Unpin,
{
let mut buffer = serde_yaml::to_string(value)
.with_kind(crate::ErrorKind::Serialization)?
.into_bytes();
buffer.extend_from_slice(b"\n");
writer.write_all(&buffer).await?;
Ok(())
}
pub async fn from_toml_async_reader<T, R>(mut reader: R) -> Result<T, crate::Error>
where
T: for<'de> serde::Deserialize<'de>,
R: AsyncRead + Unpin,
{
let mut buffer = Vec::new();
reader.read_to_end(&mut buffer).await?;
serde_toml::from_str(std::str::from_utf8(&buffer)?)
.map_err(color_eyre::eyre::Error::from)
.with_kind(crate::ErrorKind::Deserialization)
}
pub async fn to_toml_async_writer<T, W>(mut writer: W, value: &T) -> Result<(), crate::Error>
where
T: serde::Serialize,
W: AsyncWrite + Unpin,
{
let mut buffer = serde_toml::to_string(value)
.with_kind(crate::ErrorKind::Serialization)?
.into_bytes();
buffer.extend_from_slice(b"\n");
writer.write_all(&buffer).await?;
Ok(())
}
pub async fn from_cbor_async_reader<T, R>(mut reader: R) -> Result<T, crate::Error>
where
T: for<'de> serde::Deserialize<'de>,
R: AsyncRead + Unpin,
{
let mut buffer = Vec::new();
reader.read_to_end(&mut buffer).await?;
serde_cbor::de::from_reader(buffer.as_slice())
.map_err(color_eyre::eyre::Error::from)
.with_kind(crate::ErrorKind::Deserialization)
}
pub async fn to_cbor_async_writer<T, W>(mut writer: W, value: &T) -> Result<(), crate::Error>
where
T: serde::Serialize,
W: AsyncWrite + Unpin,
{
let mut buffer = Vec::new();
serde_cbor::ser::into_writer(value, &mut buffer).with_kind(crate::ErrorKind::Serialization)?;
buffer.extend_from_slice(b"\n");
writer.write_all(&buffer).await?;
Ok(())
}
pub async fn from_json_async_reader<T, R>(mut reader: R) -> Result<T, crate::Error>
where
T: for<'de> serde::Deserialize<'de>,
R: AsyncRead + Unpin,
{
let mut buffer = Vec::new();
reader.read_to_end(&mut buffer).await?;
serde_json::from_slice(&buffer)
.map_err(color_eyre::eyre::Error::from)
.with_kind(crate::ErrorKind::Deserialization)
}
pub async fn to_json_async_writer<T, W>(mut writer: W, value: &T) -> Result<(), crate::Error>
where
T: serde::Serialize,
W: AsyncWrite + Unpin,
{
let buffer = serde_json::to_string(value).with_kind(crate::ErrorKind::Serialization)?;
writer.write_all(&buffer.as_bytes()).await?;
Ok(())
}
pub async fn to_json_pretty_async_writer<T, W>(mut writer: W, value: &T) -> Result<(), crate::Error>
where
T: serde::Serialize,
W: AsyncWrite + Unpin,
{
let mut buffer =
serde_json::to_string_pretty(value).with_kind(crate::ErrorKind::Serialization)?;
buffer.push_str("\n");
writer.write_all(&buffer.as_bytes()).await?;
Ok(())
}
pub async fn copy_and_shutdown<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
r: &mut R,
mut w: W,
) -> Result<(), std::io::Error> {
tokio::io::copy(r, &mut w).await?;
w.flush().await?;
w.shutdown().await?;
Ok(())
}
pub fn dir_size<'a, P: AsRef<Path> + 'a + Send + Sync>(
path: P,
ctr: Option<&'a Counter>,
) -> BoxFuture<'a, Result<u64, std::io::Error>> {
async move {
tokio_stream::wrappers::ReadDirStream::new(tokio::fs::read_dir(path.as_ref()).await?)
.try_fold(0, |acc, e| async move {
let m = e.metadata().await?;
Ok(acc
+ if m.is_file() {
if let Some(ctr) = ctr {
ctr.add(m.len());
}
m.len()
} else if m.is_dir() {
dir_size(e.path(), ctr).await?
} else {
0
})
})
.await
}
.boxed()
}
pub fn response_to_reader(response: reqwest::Response) -> impl AsyncRead + Unpin {
tokio_util::io::StreamReader::new(response.bytes_stream().map_err(|e| {
std::io::Error::new(
if e.is_connect() {
std::io::ErrorKind::ConnectionRefused
} else if e.is_timeout() {
std::io::ErrorKind::TimedOut
} else {
std::io::ErrorKind::Other
},
e,
)
}))
}
#[pin_project::pin_project]
pub struct BufferedWriteReader {
#[pin]
hdl: Fuse<NonDetachingJoinHandle<Result<(), std::io::Error>>>,
#[pin]
rdr: DuplexStream,
}
impl BufferedWriteReader {
pub fn new<
W: FnOnce(WriteHalf<DuplexStream>) -> Fut,
Fut: Future<Output = Result<(), std::io::Error>> + Send + Sync + 'static,
>(
write_fn: W,
max_buf_size: usize,
) -> Self {
let (w, rdr) = duplex(max_buf_size);
let (_, w) = tokio::io::split(w);
BufferedWriteReader {
hdl: NonDetachingJoinHandle::from(tokio::spawn(write_fn(w))).fuse(),
rdr,
}
}
}
impl AsyncRead for BufferedWriteReader {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let this = self.project();
let res = this.rdr.poll_read(cx, buf);
match this.hdl.poll(cx) {
Poll::Ready(Ok(Err(e))) => return Poll::Ready(Err(e)),
Poll::Ready(Err(e)) => {
return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, e)))
}
_ => res,
}
}
}
pub trait CursorExt {
fn pure_read(&mut self, buf: &mut ReadBuf<'_>);
}
impl<T: AsRef<[u8]>> CursorExt for Cursor<T> {
fn pure_read(&mut self, buf: &mut ReadBuf<'_>) {
let end = self.position() as usize
+ std::cmp::min(
buf.remaining(),
self.get_ref().as_ref().len() - self.position() as usize,
);
buf.put_slice(&self.get_ref().as_ref()[self.position() as usize..end]);
self.set_position(end as u64);
}
}
#[pin_project::pin_project]
#[derive(Debug)]
pub struct BackTrackingReader<T> {
#[pin]
reader: T,
buffer: Cursor<Vec<u8>>,
buffering: bool,
}
impl<T> BackTrackingReader<T> {
pub fn new(reader: T) -> Self {
Self {
reader,
buffer: Cursor::new(Vec::new()),
buffering: false,
}
}
pub fn start_buffering(&mut self) {
self.buffer.set_position(0);
self.buffer.get_mut().truncate(0);
self.buffering = true;
}
pub fn stop_buffering(&mut self) {
self.buffer.set_position(0);
self.buffer.get_mut().truncate(0);
self.buffering = false;
}
pub fn rewind(&mut self) {
self.buffering = false;
}
pub fn unwrap(self) -> T {
self.reader
}
}
impl<T: AsyncRead> AsyncRead for BackTrackingReader<T> {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let this = self.project();
if *this.buffering {
let filled = buf.filled().len();
let res = this.reader.poll_read(cx, buf);
this.buffer
.get_mut()
.extend_from_slice(&buf.filled()[filled..]);
res
} else {
let mut ready = false;
if (this.buffer.position() as usize) < this.buffer.get_ref().len() {
this.buffer.pure_read(buf);
ready = true;
}
if buf.remaining() > 0 {
match this.reader.poll_read(cx, buf) {
Poll::Pending => {
if ready {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
a => a,
}
} else {
Poll::Ready(Ok(()))
}
}
}
}
impl<T: AsyncWrite> AsyncWrite for BackTrackingReader<T> {
fn is_write_vectored(&self) -> bool {
self.reader.is_write_vectored()
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
self.project().reader.poll_flush(cx)
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
self.project().reader.poll_shutdown(cx)
}
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
self.project().reader.poll_write(cx, buf)
}
fn poll_write_vectored(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
self.project().reader.poll_write_vectored(cx, bufs)
}
}
pub struct Counter {
atomic: AtomicU64,
ordering: std::sync::atomic::Ordering,
}
impl Counter {
pub fn new(init: u64, ordering: std::sync::atomic::Ordering) -> Self {
Self {
atomic: AtomicU64::new(init),
ordering,
}
}
pub fn load(&self) -> u64 {
self.atomic.load(self.ordering)
}
pub fn add(&self, value: u64) {
self.atomic.fetch_add(value, self.ordering);
}
}
#[pin_project::pin_project]
pub struct CountingReader<'a, R> {
ctr: &'a Counter,
#[pin]
rdr: R,
}
impl<'a, R> CountingReader<'a, R> {
pub fn new(rdr: R, ctr: &'a Counter) -> Self {
Self { ctr, rdr }
}
pub fn into_inner(self) -> R {
self.rdr
}
}
impl<'a, R: AsyncRead> AsyncRead for CountingReader<'a, R> {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let this = self.project();
let start = buf.filled().len();
let res = this.rdr.poll_read(cx, buf);
let len = buf.filled().len() - start;
if len > 0 {
this.ctr.add(len as u64);
}
res
}
}
pub fn dir_copy<'a, P0: AsRef<Path> + 'a + Send + Sync, P1: AsRef<Path> + 'a + Send + Sync>(
src: P0,
dst: P1,
ctr: Option<&'a Counter>,
) -> BoxFuture<'a, Result<(), crate::Error>> {
async move {
let m = tokio::fs::metadata(&src).await?;
let dst_path = dst.as_ref();
tokio::fs::create_dir_all(&dst_path).await.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
format!("mkdir {}", dst_path.display()),
)
})?;
tokio::fs::set_permissions(&dst_path, m.permissions())
.await
.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
format!("chmod {}", dst_path.display()),
)
})?;
let tmp_dst_path = dst_path.to_owned();
tokio::task::spawn_blocking(move || {
nix::unistd::chown(
&tmp_dst_path,
Some(Uid::from_raw(m.uid())),
Some(Gid::from_raw(m.gid())),
)
})
.await
.with_kind(crate::ErrorKind::Unknown)?
.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
format!("chown {}", dst_path.display()),
)
})?;
tokio_stream::wrappers::ReadDirStream::new(tokio::fs::read_dir(src.as_ref()).await?)
.map_err(|e| crate::Error::new(e, crate::ErrorKind::Filesystem))
.try_for_each(|e| async move {
let m = e.metadata().await?;
let src_path = e.path();
let dst_path = dst_path.join(e.file_name());
if m.is_file() {
let mut dst_file = tokio::fs::File::create(&dst_path).await.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
format!("create {}", dst_path.display()),
)
})?;
let mut rdr = tokio::fs::File::open(&src_path).await.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
format!("open {}", src_path.display()),
)
})?;
if let Some(ctr) = ctr {
tokio::io::copy(&mut CountingReader::new(rdr, ctr), &mut dst_file).await
} else {
tokio::io::copy(&mut rdr, &mut dst_file).await
}
.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
format!("cp {} -> {}", src_path.display(), dst_path.display()),
)
})?;
dst_file.flush().await?;
dst_file.shutdown().await?;
dst_file.sync_all().await?;
drop(dst_file);
let tmp_dst_path = dst_path.clone();
tokio::task::spawn_blocking(move || {
nix::unistd::chown(
&tmp_dst_path,
Some(Uid::from_raw(m.uid())),
Some(Gid::from_raw(m.gid())),
)
})
.await
.with_kind(crate::ErrorKind::Unknown)?
.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
format!("chown {}", dst_path.display()),
)
})?;
} else if m.is_dir() {
dir_copy(src_path, dst_path, ctr).await?;
} else if m.file_type().is_symlink() {
tokio::fs::symlink(
tokio::fs::read_link(&src_path).await.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
format!("readlink {}", src_path.display()),
)
})?,
&dst_path,
)
.await
.with_ctx(|_| {
(
crate::ErrorKind::Filesystem,
format!("cp -P {} -> {}", src_path.display(), dst_path.display()),
)
})?;
// Do not set permissions (see https://unix.stackexchange.com/questions/87200/change-permissions-for-a-symbolic-link)
}
Ok(())
})
.await?;
Ok(())
}
.boxed()
}
#[pin_project::pin_project]
pub struct TimeoutStream<S: AsyncRead + AsyncWrite = TcpStream> {
timeout: Duration,
#[pin]
sleep: Sleep,
#[pin]
stream: S,
}
impl<S: AsyncRead + AsyncWrite> TimeoutStream<S> {
pub fn new(stream: S, timeout: Duration) -> Self {
Self {
timeout,
sleep: tokio::time::sleep(timeout),
stream,
}
}
}
impl<S: AsyncRead + AsyncWrite> AsyncRead for TimeoutStream<S> {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
let mut this = self.project();
if let std::task::Poll::Ready(_) = this.sleep.as_mut().poll(cx) {
return std::task::Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"timed out",
)));
}
let res = this.stream.poll_read(cx, buf);
if res.is_ready() {
this.sleep.reset(Instant::now() + *this.timeout);
}
res
}
}
impl<S: AsyncRead + AsyncWrite> AsyncWrite for TimeoutStream<S> {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
let this = self.project();
let res = this.stream.poll_write(cx, buf);
if res.is_ready() {
this.sleep.reset(Instant::now() + *this.timeout);
}
res
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project();
let res = this.stream.poll_flush(cx);
if res.is_ready() {
this.sleep.reset(Instant::now() + *this.timeout);
}
res
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
let this = self.project();
let res = this.stream.poll_shutdown(cx);
if res.is_ready() {
this.sleep.reset(Instant::now() + *this.timeout);
}
res
}
}

View File

@@ -0,0 +1,52 @@
use tracing::Subscriber;
use tracing_subscriber::util::SubscriberInitExt;
#[derive(Clone)]
pub struct EmbassyLogger {}
impl EmbassyLogger {
fn base_subscriber() -> impl Subscriber {
use tracing_error::ErrorLayer;
use tracing_subscriber::prelude::*;
use tracing_subscriber::{fmt, EnvFilter};
let filter_layer = EnvFilter::builder()
.with_default_directive(
format!("{}=info", std::module_path!().split("::").next().unwrap())
.parse()
.unwrap(),
)
.from_env_lossy();
#[cfg(feature = "unstable")]
let filter_layer = filter_layer
.add_directive("tokio=trace".parse().unwrap())
.add_directive("runtime=trace".parse().unwrap());
let fmt_layer = fmt::layer().with_target(true);
let sub = tracing_subscriber::registry()
.with(filter_layer)
.with(fmt_layer)
.with(ErrorLayer::default());
#[cfg(feature = "unstable")]
let sub = sub.with(console_subscriber::spawn());
sub
}
pub fn init() -> Self {
Self::base_subscriber().init();
color_eyre::install().unwrap_or_else(|_| tracing::warn!("tracing too many times"));
EmbassyLogger {}
}
}
#[tokio::test]
pub async fn order_level() {
assert!(tracing::Level::WARN > tracing::Level::ERROR)
}
#[test]
pub fn module() {
println!("{}", module_path!())
}

View File

@@ -0,0 +1,63 @@
use models::{Error, ResultExt};
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use crate::util::Invoke;
const KNOWN_CLASSES: &[&str] = &["processor", "display"];
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "class")]
#[serde(rename_all = "kebab-case")]
pub enum LshwDevice {
Processor(LshwProcessor),
Display(LshwDisplay),
}
impl LshwDevice {
pub fn class(&self) -> &'static str {
match self {
Self::Processor(_) => "processor",
Self::Display(_) => "display",
}
}
pub fn product(&self) -> &str {
match self {
Self::Processor(hw) => hw.product.as_str(),
Self::Display(hw) => hw.product.as_str(),
}
}
}
#[derive(Debug, Deserialize, Serialize)]
pub struct LshwProcessor {
pub product: String,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct LshwDisplay {
pub product: String,
}
pub async fn lshw() -> Result<Vec<LshwDevice>, Error> {
let mut cmd = Command::new("lshw");
cmd.arg("-json");
for class in KNOWN_CLASSES {
cmd.arg("-class").arg(*class);
}
Ok(
serde_json::from_slice::<Vec<serde_json::Value>>(
&cmd.invoke(crate::ErrorKind::Lshw).await?,
)
.with_kind(crate::ErrorKind::Deserialization)?
.into_iter()
.filter_map(|v| match serde_json::from_value(v) {
Ok(a) => Some(a),
Err(e) => {
tracing::error!("Failed to parse lshw output: {e}");
tracing::debug!("{e:?}");
None
}
})
.collect(),
)
}

View File

@@ -0,0 +1,468 @@
use std::collections::BTreeMap;
use std::future::Future;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::process::Stdio;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use async_trait::async_trait;
use clap::ArgMatches;
use color_eyre::eyre::{self, eyre};
use fd_lock_rs::FdLock;
use helpers::canonicalize;
pub use helpers::NonDetachingJoinHandle;
use lazy_static::lazy_static;
pub use models::Version;
use pin_project::pin_project;
use sha2::Digest;
use tokio::fs::File;
use tokio::sync::{Mutex, OwnedMutexGuard, RwLock};
use tracing::instrument;
use crate::shutdown::Shutdown;
use crate::{Error, ErrorKind, ResultExt as _};
pub mod config;
pub mod cpupower;
pub mod crypto;
pub mod docker;
pub mod http_reader;
pub mod io;
pub mod logger;
pub mod lshw;
pub mod serde;
#[derive(Clone, Copy, Debug, ::serde::Deserialize, ::serde::Serialize)]
pub enum Never {}
impl Never {}
impl Never {
pub fn absurd<T>(self) -> T {
match self {}
}
}
impl std::fmt::Display for Never {
fn fmt(&self, _f: &mut std::fmt::Formatter) -> std::fmt::Result {
self.absurd()
}
}
impl std::error::Error for Never {}
#[async_trait::async_trait]
pub trait Invoke<'a> {
type Extended<'ext>
where
Self: 'ext,
'ext: 'a;
fn timeout<'ext: 'a>(&'ext mut self, timeout: Option<Duration>) -> Self::Extended<'ext>;
fn input<'ext: 'a, Input: tokio::io::AsyncRead + Unpin + Send>(
&'ext mut self,
input: Option<&'ext mut Input>,
) -> Self::Extended<'ext>;
async fn invoke(&mut self, error_kind: crate::ErrorKind) -> Result<Vec<u8>, Error>;
}
pub struct ExtendedCommand<'a> {
cmd: &'a mut tokio::process::Command,
timeout: Option<Duration>,
input: Option<&'a mut (dyn tokio::io::AsyncRead + Unpin + Send)>,
}
impl<'a> std::ops::Deref for ExtendedCommand<'a> {
type Target = tokio::process::Command;
fn deref(&self) -> &Self::Target {
&*self.cmd
}
}
impl<'a> std::ops::DerefMut for ExtendedCommand<'a> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.cmd
}
}
#[async_trait::async_trait]
impl<'a> Invoke<'a> for tokio::process::Command {
type Extended<'ext> = ExtendedCommand<'ext>
where
Self: 'ext,
'ext: 'a;
fn timeout<'ext: 'a>(&'ext mut self, timeout: Option<Duration>) -> Self::Extended<'ext> {
ExtendedCommand {
cmd: self,
timeout,
input: None,
}
}
fn input<'ext: 'a, Input: tokio::io::AsyncRead + Unpin + Send>(
&'ext mut self,
input: Option<&'ext mut Input>,
) -> Self::Extended<'ext> {
ExtendedCommand {
cmd: self,
timeout: None,
input: if let Some(input) = input {
Some(&mut *input)
} else {
None
},
}
}
async fn invoke(&mut self, error_kind: crate::ErrorKind) -> Result<Vec<u8>, Error> {
ExtendedCommand {
cmd: self,
timeout: None,
input: None,
}
.invoke(error_kind)
.await
}
}
#[async_trait::async_trait]
impl<'a> Invoke<'a> for ExtendedCommand<'a> {
type Extended<'ext> = &'ext mut ExtendedCommand<'ext>
where
Self: 'ext,
'ext: 'a;
fn timeout<'ext: 'a>(&'ext mut self, timeout: Option<Duration>) -> Self::Extended<'ext> {
self.timeout = timeout;
self
}
fn input<'ext: 'a, Input: tokio::io::AsyncRead + Unpin + Send>(
&'ext mut self,
input: Option<&'ext mut Input>,
) -> Self::Extended<'ext> {
self.input = if let Some(input) = input {
Some(&mut *input)
} else {
None
};
self
}
async fn invoke(&mut self, error_kind: crate::ErrorKind) -> Result<Vec<u8>, Error> {
self.cmd.kill_on_drop(true);
if self.input.is_some() {
self.cmd.stdin(Stdio::piped());
}
self.cmd.stdout(Stdio::piped());
self.cmd.stderr(Stdio::piped());
let mut child = self.cmd.spawn()?;
if let (Some(mut stdin), Some(input)) = (child.stdin.take(), self.input.take()) {
use tokio::io::AsyncWriteExt;
tokio::io::copy(input, &mut stdin).await?;
stdin.flush().await?;
stdin.shutdown().await?;
drop(stdin);
}
let res = match self.timeout {
None => child.wait_with_output().await?,
Some(t) => tokio::time::timeout(t, child.wait_with_output())
.await
.with_kind(ErrorKind::Timeout)??,
};
crate::ensure_code!(
res.status.success(),
error_kind,
"{}",
Some(&res.stderr)
.filter(|a| !a.is_empty())
.or(Some(&res.stdout))
.filter(|a| !a.is_empty())
.and_then(|a| std::str::from_utf8(a).ok())
.unwrap_or(&format!("Unknown Error ({})", res.status))
);
Ok(res.stdout)
}
}
pub trait Apply: Sized {
fn apply<O, F: FnOnce(Self) -> O>(self, func: F) -> O {
func(self)
}
}
pub trait ApplyRef {
fn apply_ref<O, F: FnOnce(&Self) -> O>(&self, func: F) -> O {
func(&self)
}
fn apply_mut<O, F: FnOnce(&mut Self) -> O>(&mut self, func: F) -> O {
func(self)
}
}
impl<T> Apply for T {}
impl<T> ApplyRef for T {}
pub async fn daemon<F: FnMut() -> Fut, Fut: Future<Output = ()> + Send + 'static>(
mut f: F,
cooldown: std::time::Duration,
mut shutdown: tokio::sync::broadcast::Receiver<Option<Shutdown>>,
) -> Result<(), eyre::Error> {
loop {
tokio::select! {
_ = shutdown.recv() => return Ok(()),
_ = tokio::time::sleep(cooldown) => (),
}
match tokio::spawn(f()).await {
Err(e) if e.is_panic() => return Err(eyre!("daemon panicked!")),
_ => (),
}
}
}
pub trait SOption<T> {}
pub struct SSome<T>(T);
impl<T> SSome<T> {
pub fn into(self) -> T {
self.0
}
}
impl<T> From<T> for SSome<T> {
fn from(t: T) -> Self {
SSome(t)
}
}
impl<T> SOption<T> for SSome<T> {}
pub struct SNone<T>(PhantomData<T>);
impl<T> SNone<T> {
pub fn new() -> Self {
SNone(PhantomData)
}
}
impl<T> SOption<T> for SNone<T> {}
#[async_trait]
pub trait AsyncFileExt: Sized {
async fn maybe_open<P: AsRef<Path> + Send + Sync>(path: P) -> std::io::Result<Option<Self>>;
async fn delete<P: AsRef<Path> + Send + Sync>(path: P) -> std::io::Result<()>;
}
#[async_trait]
impl AsyncFileExt for File {
async fn maybe_open<P: AsRef<Path> + Send + Sync>(path: P) -> std::io::Result<Option<Self>> {
match File::open(path).await {
Ok(f) => Ok(Some(f)),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
Err(e) => Err(e),
}
}
async fn delete<P: AsRef<Path> + Send + Sync>(path: P) -> std::io::Result<()> {
if let Ok(m) = tokio::fs::metadata(path.as_ref()).await {
if m.is_dir() {
tokio::fs::remove_dir_all(path).await
} else {
tokio::fs::remove_file(path).await
}
} else {
Ok(())
}
}
}
pub struct FmtWriter<W: std::fmt::Write>(W);
impl<W: std::fmt::Write> std::io::Write for FmtWriter<W> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.0
.write_str(
std::str::from_utf8(buf)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?,
)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
Ok(buf.len())
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
pub fn display_none<T>(_: T, _: &ArgMatches) {}
pub struct Container<T>(RwLock<Option<T>>);
impl<T> Container<T> {
pub fn new(value: Option<T>) -> Self {
Container(RwLock::new(value))
}
pub async fn set(&self, value: T) -> Option<T> {
std::mem::replace(&mut *self.0.write().await, Some(value))
}
pub async fn take(&self) -> Option<T> {
self.0.write().await.take()
}
pub async fn is_empty(&self) -> bool {
self.0.read().await.is_none()
}
pub async fn drop(&self) {
*self.0.write().await = None;
}
}
#[pin_project]
pub struct HashWriter<H: Digest, W: tokio::io::AsyncWrite> {
hasher: H,
#[pin]
writer: W,
}
impl<H: Digest, W: tokio::io::AsyncWrite> HashWriter<H, W> {
pub fn new(hasher: H, writer: W) -> Self {
HashWriter { hasher, writer }
}
pub fn finish(self) -> (H, W) {
(self.hasher, self.writer)
}
pub fn inner(&self) -> &W {
&self.writer
}
pub fn inner_mut(&mut self) -> &mut W {
&mut self.writer
}
}
impl<H: Digest, W: tokio::io::AsyncWrite> tokio::io::AsyncWrite for HashWriter<H, W> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let this = self.project();
let written = tokio::io::AsyncWrite::poll_write(this.writer, cx, buf);
match written {
// only update the hasher once
Poll::Ready(res) => {
if let Ok(n) = res {
this.hasher.update(&buf[..n]);
}
Poll::Ready(res)
}
Poll::Pending => Poll::Pending,
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<std::io::Result<()>> {
self.project().writer.poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<std::io::Result<()>> {
self.project().writer.poll_shutdown(cx)
}
}
pub trait IntoDoubleEndedIterator<T>: IntoIterator<Item = T> {
type IntoIter: Iterator<Item = T> + DoubleEndedIterator;
fn into_iter(self) -> <Self as IntoDoubleEndedIterator<T>>::IntoIter;
}
impl<T, U> IntoDoubleEndedIterator<U> for T
where
T: IntoIterator<Item = U>,
<T as IntoIterator>::IntoIter: DoubleEndedIterator,
{
type IntoIter = <T as IntoIterator>::IntoIter;
fn into_iter(self) -> <Self as IntoDoubleEndedIterator<U>>::IntoIter {
IntoIterator::into_iter(self)
}
}
pub struct GeneralBoxedGuard(Option<Box<dyn FnOnce() + Send + Sync>>);
impl GeneralBoxedGuard {
pub fn new(f: impl FnOnce() + 'static + Send + Sync) -> Self {
GeneralBoxedGuard(Some(Box::new(f)))
}
pub fn drop(mut self) {
self.0.take().unwrap()()
}
pub fn drop_without_action(mut self) {
self.0 = None;
}
}
impl Drop for GeneralBoxedGuard {
fn drop(&mut self) {
if let Some(destroy) = self.0.take() {
destroy();
}
}
}
pub struct GeneralGuard<F: FnOnce() -> T, T = ()>(Option<F>);
impl<F: FnOnce() -> T, T> GeneralGuard<F, T> {
pub fn new(f: F) -> Self {
GeneralGuard(Some(f))
}
pub fn drop(mut self) -> T {
self.0.take().unwrap()()
}
pub fn drop_without_action(mut self) {
self.0 = None;
}
}
impl<F: FnOnce() -> T, T> Drop for GeneralGuard<F, T> {
fn drop(&mut self) {
if let Some(destroy) = self.0.take() {
destroy();
}
}
}
pub struct FileLock(OwnedMutexGuard<()>, Option<FdLock<File>>);
impl Drop for FileLock {
fn drop(&mut self) {
if let Some(fd_lock) = self.1.take() {
tokio::task::spawn_blocking(|| fd_lock.unlock(true).map_err(|(_, e)| e).unwrap());
}
}
}
impl FileLock {
#[instrument(skip_all)]
pub async fn new(path: impl AsRef<Path> + Send + Sync, blocking: bool) -> Result<Self, Error> {
lazy_static! {
static ref INTERNAL_LOCKS: Mutex<BTreeMap<PathBuf, Arc<Mutex<()>>>> =
Mutex::new(BTreeMap::new());
}
let path = canonicalize(path.as_ref(), true)
.await
.with_kind(ErrorKind::Filesystem)?;
let mut internal_locks = INTERNAL_LOCKS.lock().await;
if !internal_locks.contains_key(&path) {
internal_locks.insert(path.clone(), Arc::new(Mutex::new(())));
}
let tex = internal_locks.get(&path).unwrap().clone();
drop(internal_locks);
let tex_guard = if blocking {
tex.lock_owned().await
} else {
tex.try_lock_owned()
.with_kind(crate::ErrorKind::Filesystem)?
};
let parent = path.parent().unwrap_or(Path::new("/"));
if tokio::fs::metadata(parent).await.is_err() {
tokio::fs::create_dir_all(parent)
.await
.with_ctx(|_| (crate::ErrorKind::Filesystem, parent.display().to_string()))?;
}
let f = File::create(&path)
.await
.with_ctx(|_| (crate::ErrorKind::Filesystem, path.display().to_string()))?;
let file_guard = tokio::task::spawn_blocking(move || {
fd_lock_rs::FdLock::lock(f, fd_lock_rs::LockType::Exclusive, blocking)
})
.await
.with_kind(crate::ErrorKind::Unknown)?
.with_kind(crate::ErrorKind::Filesystem)?;
Ok(FileLock(tex_guard, Some(file_guard)))
}
pub async fn unlock(mut self) -> Result<(), Error> {
if let Some(fd_lock) = self.1.take() {
tokio::task::spawn_blocking(|| fd_lock.unlock(true).map_err(|(_, e)| e))
.await
.with_kind(crate::ErrorKind::Unknown)?
.with_kind(crate::ErrorKind::Filesystem)?;
}
Ok(())
}
}
pub fn assure_send<T: Send>(x: T) -> T {
x
}

View File

@@ -0,0 +1,845 @@
use std::marker::PhantomData;
use std::ops::Deref;
use std::process::exit;
use std::str::FromStr;
use clap::ArgMatches;
use color_eyre::eyre::eyre;
use serde::ser::{SerializeMap, SerializeSeq};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::Value;
use super::IntoDoubleEndedIterator;
use crate::{Error, ResultExt};
pub fn deserialize_from_str<
'de,
D: serde::de::Deserializer<'de>,
T: FromStr<Err = E>,
E: std::fmt::Display,
>(
deserializer: D,
) -> std::result::Result<T, D::Error> {
struct Visitor<T: FromStr<Err = E>, E>(std::marker::PhantomData<T>);
impl<'de, T: FromStr<Err = Err>, Err: std::fmt::Display> serde::de::Visitor<'de>
for Visitor<T, Err>
{
type Value = T;
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "a parsable string")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
v.parse().map_err(|e| serde::de::Error::custom(e))
}
}
deserializer.deserialize_str(Visitor(std::marker::PhantomData))
}
pub fn deserialize_from_str_opt<
'de,
D: serde::de::Deserializer<'de>,
T: FromStr<Err = E>,
E: std::fmt::Display,
>(
deserializer: D,
) -> std::result::Result<Option<T>, D::Error> {
struct Visitor<T: FromStr<Err = E>, E>(std::marker::PhantomData<T>);
impl<'de, T: FromStr<Err = Err>, Err: std::fmt::Display> serde::de::Visitor<'de>
for Visitor<T, Err>
{
type Value = Option<T>;
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "a parsable string")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
v.parse().map(Some).map_err(|e| serde::de::Error::custom(e))
}
fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: serde::de::Deserializer<'de>,
{
deserializer.deserialize_str(Visitor(std::marker::PhantomData))
}
fn visit_none<E>(self) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(None)
}
fn visit_unit<E>(self) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(None)
}
}
deserializer.deserialize_any(Visitor(std::marker::PhantomData))
}
pub fn serialize_display<T: std::fmt::Display, S: Serializer>(
t: &T,
serializer: S,
) -> Result<S::Ok, S::Error> {
String::serialize(&t.to_string(), serializer)
}
pub fn serialize_display_opt<T: std::fmt::Display, S: Serializer>(
t: &Option<T>,
serializer: S,
) -> Result<S::Ok, S::Error> {
Option::<String>::serialize(&t.as_ref().map(|t| t.to_string()), serializer)
}
pub mod ed25519_pubkey {
use ed25519_dalek::VerifyingKey;
use serde::de::{Error, Unexpected, Visitor};
use serde::{Deserializer, Serializer};
pub fn serialize<S: Serializer>(
pubkey: &VerifyingKey,
serializer: S,
) -> Result<S::Ok, S::Error> {
serializer.serialize_str(&base32::encode(
base32::Alphabet::RFC4648 { padding: true },
pubkey.as_bytes(),
))
}
pub fn deserialize<'de, D: Deserializer<'de>>(
deserializer: D,
) -> Result<VerifyingKey, D::Error> {
struct PubkeyVisitor;
impl<'de> Visitor<'de> for PubkeyVisitor {
type Value = ed25519_dalek::VerifyingKey;
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "an RFC4648 encoded string")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: Error,
{
VerifyingKey::from_bytes(
&<[u8; 32]>::try_from(
base32::decode(base32::Alphabet::RFC4648 { padding: true }, v).ok_or(
Error::invalid_value(Unexpected::Str(v), &"an RFC4648 encoded string"),
)?,
)
.map_err(|e| Error::invalid_length(e.len(), &"32 bytes"))?,
)
.map_err(Error::custom)
}
}
deserializer.deserialize_str(PubkeyVisitor)
}
}
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum ValuePrimative {
Null,
Boolean(bool),
String(String),
Number(serde_json::Number),
}
impl<'de> serde::de::Deserialize<'de> for ValuePrimative {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
struct Visitor;
impl<'de> serde::de::Visitor<'de> for Visitor {
type Value = ValuePrimative;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "a JSON primative value")
}
fn visit_unit<E>(self) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(ValuePrimative::Null)
}
fn visit_none<E>(self) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(ValuePrimative::Null)
}
fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(ValuePrimative::Boolean(v))
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(ValuePrimative::String(v.to_owned()))
}
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(ValuePrimative::String(v))
}
fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(ValuePrimative::Number(
serde_json::Number::from_f64(v as f64).ok_or_else(|| {
serde::de::Error::invalid_value(
serde::de::Unexpected::Float(v as f64),
&"a finite number",
)
})?,
))
}
fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(ValuePrimative::Number(
serde_json::Number::from_f64(v).ok_or_else(|| {
serde::de::Error::invalid_value(
serde::de::Unexpected::Float(v),
&"a finite number",
)
})?,
))
}
fn visit_u8<E>(self, v: u8) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(ValuePrimative::Number(v.into()))
}
fn visit_u16<E>(self, v: u16) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(ValuePrimative::Number(v.into()))
}
fn visit_u32<E>(self, v: u32) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(ValuePrimative::Number(v.into()))
}
fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(ValuePrimative::Number(v.into()))
}
fn visit_i8<E>(self, v: i8) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(ValuePrimative::Number(v.into()))
}
fn visit_i16<E>(self, v: i16) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(ValuePrimative::Number(v.into()))
}
fn visit_i32<E>(self, v: i32) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(ValuePrimative::Number(v.into()))
}
fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(ValuePrimative::Number(v.into()))
}
}
deserializer.deserialize_any(Visitor)
}
}
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum IoFormat {
Json,
JsonPretty,
Yaml,
Cbor,
Toml,
TomlPretty,
}
impl Default for IoFormat {
fn default() -> Self {
IoFormat::JsonPretty
}
}
impl std::fmt::Display for IoFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use IoFormat::*;
match self {
Json => write!(f, "JSON"),
JsonPretty => write!(f, "JSON (pretty)"),
Yaml => write!(f, "YAML"),
Cbor => write!(f, "CBOR"),
Toml => write!(f, "TOML"),
TomlPretty => write!(f, "TOML (pretty)"),
}
}
}
impl std::str::FromStr for IoFormat {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
serde_json::from_value(Value::String(s.to_owned()))
.with_kind(crate::ErrorKind::Deserialization)
}
}
impl IoFormat {
pub fn to_writer<W: std::io::Write, T: Serialize>(
&self,
mut writer: W,
value: &T,
) -> Result<(), Error> {
match self {
IoFormat::Json => {
serde_json::to_writer(writer, value).with_kind(crate::ErrorKind::Serialization)
}
IoFormat::JsonPretty => serde_json::to_writer_pretty(writer, value)
.with_kind(crate::ErrorKind::Serialization),
IoFormat::Yaml => {
serde_yaml::to_writer(writer, value).with_kind(crate::ErrorKind::Serialization)
}
IoFormat::Cbor => serde_cbor::ser::into_writer(value, writer)
.with_kind(crate::ErrorKind::Serialization),
IoFormat::Toml => writer
.write_all(
serde_toml::to_string(
&serde_toml::Value::try_from(value)
.with_kind(crate::ErrorKind::Serialization)?,
)
.with_kind(crate::ErrorKind::Serialization)?
.as_bytes(),
)
.with_kind(crate::ErrorKind::Serialization),
IoFormat::TomlPretty => writer
.write_all(
serde_toml::to_string_pretty(
&serde_toml::Value::try_from(value)
.with_kind(crate::ErrorKind::Serialization)?,
)
.with_kind(crate::ErrorKind::Serialization)?
.as_bytes(),
)
.with_kind(crate::ErrorKind::Serialization),
}
}
pub fn to_vec<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, Error> {
match self {
IoFormat::Json => serde_json::to_vec(value).with_kind(crate::ErrorKind::Serialization),
IoFormat::JsonPretty => {
serde_json::to_vec_pretty(value).with_kind(crate::ErrorKind::Serialization)
}
IoFormat::Yaml => serde_yaml::to_string(value)
.with_kind(crate::ErrorKind::Serialization)
.map(|s| s.into_bytes()),
IoFormat::Cbor => {
let mut res = Vec::new();
serde_cbor::ser::into_writer(value, &mut res)
.with_kind(crate::ErrorKind::Serialization)?;
Ok(res)
}
IoFormat::Toml => serde_toml::to_string(
&serde_toml::Value::try_from(value).with_kind(crate::ErrorKind::Serialization)?,
)
.with_kind(crate::ErrorKind::Serialization)
.map(|s| s.into_bytes()),
IoFormat::TomlPretty => serde_toml::to_string_pretty(
&serde_toml::Value::try_from(value).with_kind(crate::ErrorKind::Serialization)?,
)
.map(|s| s.into_bytes())
.with_kind(crate::ErrorKind::Serialization),
}
}
/// BLOCKING
pub fn from_reader<R: std::io::Read, T: for<'de> Deserialize<'de>>(
&self,
mut reader: R,
) -> Result<T, Error> {
match self {
IoFormat::Json | IoFormat::JsonPretty => {
serde_json::from_reader(reader).with_kind(crate::ErrorKind::Deserialization)
}
IoFormat::Yaml => {
serde_yaml::from_reader(reader).with_kind(crate::ErrorKind::Deserialization)
}
IoFormat::Cbor => {
serde_cbor::de::from_reader(reader).with_kind(crate::ErrorKind::Deserialization)
}
IoFormat::Toml | IoFormat::TomlPretty => {
let mut s = String::new();
reader
.read_to_string(&mut s)
.with_kind(crate::ErrorKind::Deserialization)?;
serde_toml::from_str(&s).with_kind(crate::ErrorKind::Deserialization)
}
}
}
pub async fn from_async_reader<
R: tokio::io::AsyncRead + Unpin,
T: for<'de> Deserialize<'de>,
>(
&self,
reader: R,
) -> Result<T, Error> {
use crate::util::io::*;
match self {
IoFormat::Json | IoFormat::JsonPretty => from_json_async_reader(reader).await,
IoFormat::Yaml => from_yaml_async_reader(reader).await,
IoFormat::Cbor => from_cbor_async_reader(reader).await,
IoFormat::Toml | IoFormat::TomlPretty => from_toml_async_reader(reader).await,
}
}
pub fn from_slice<T: for<'de> Deserialize<'de>>(&self, slice: &[u8]) -> Result<T, Error> {
match self {
IoFormat::Json | IoFormat::JsonPretty => {
serde_json::from_slice(slice).with_kind(crate::ErrorKind::Deserialization)
}
IoFormat::Yaml => {
serde_yaml::from_slice(slice).with_kind(crate::ErrorKind::Deserialization)
}
IoFormat::Cbor => {
serde_cbor::de::from_reader(slice).with_kind(crate::ErrorKind::Deserialization)
}
IoFormat::Toml | IoFormat::TomlPretty => {
serde_toml::from_str(std::str::from_utf8(slice)?)
.with_kind(crate::ErrorKind::Deserialization)
}
}
}
}
pub fn display_serializable<T: Serialize>(t: T, matches: &ArgMatches) {
let format = match matches.value_of("format").map(|f| f.parse()) {
Some(Ok(f)) => f,
Some(Err(_)) => {
eprintln!("unrecognized formatter");
exit(1)
}
None => IoFormat::default(),
};
format
.to_writer(std::io::stdout(), &t)
.expect("Error serializing result to stdout")
}
pub fn parse_stdin_deserializable<T: for<'de> Deserialize<'de>>(
stdin: &mut std::io::Stdin,
matches: &ArgMatches,
) -> Result<T, Error> {
let format = match matches.value_of("format").map(|f| f.parse()) {
Some(Ok(f)) => f,
Some(Err(_)) => {
eprintln!("unrecognized formatter");
exit(1)
}
None => IoFormat::default(),
};
format.from_reader(stdin)
}
#[derive(Debug, Clone, Copy)]
pub struct Duration(std::time::Duration);
impl Deref for Duration {
type Target = std::time::Duration;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<std::time::Duration> for Duration {
fn from(t: std::time::Duration) -> Self {
Duration(t)
}
}
impl std::str::FromStr for Duration {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let units_idx = s.find(|c: char| c.is_alphabetic()).ok_or_else(|| {
Error::new(
eyre!("Must specify units for duration"),
crate::ErrorKind::Deserialization,
)
})?;
let (num, units) = s.split_at(units_idx);
use std::time::Duration;
Ok(Duration(match units {
"d" if num.contains(".") => Duration::from_secs_f64(num.parse::<f64>()? * 86_400_f64),
"d" => Duration::from_secs(num.parse::<u64>()? * 86_400),
"h" if num.contains(".") => Duration::from_secs_f64(num.parse::<f64>()? * 3_600_f64),
"h" => Duration::from_secs(num.parse::<u64>()? * 3_600),
"m" if num.contains(".") => Duration::from_secs_f64(num.parse::<f64>()? * 60_f64),
"m" => Duration::from_secs(num.parse::<u64>()? * 60),
"s" if num.contains(".") => Duration::from_secs_f64(num.parse()?),
"s" => Duration::from_secs(num.parse()?),
"ms" if num.contains(".") => Duration::from_secs_f64(num.parse::<f64>()? / 1_000_f64),
"ms" => {
let millis: u128 = num.parse()?;
Duration::new((millis / 1_000) as u64, (millis % 1_000) as u32)
}
"us" | "µs" if num.contains(".") => {
Duration::from_secs_f64(num.parse::<f64>()? / 1_000_000_f64)
}
"us" | "µs" => {
let micros: u128 = num.parse()?;
Duration::new((micros / 1_000_000) as u64, (micros % 1_000_000) as u32)
}
"ns" if num.contains(".") => {
Duration::from_secs_f64(num.parse::<f64>()? / 1_000_000_000_f64)
}
"ns" => {
let nanos: u128 = num.parse()?;
Duration::new(
(nanos / 1_000_000_000) as u64,
(nanos % 1_000_000_000) as u32,
)
}
_ => {
return Err(Error::new(
eyre!("Invalid units for duration"),
crate::ErrorKind::Deserialization,
))
}
}))
}
}
impl std::fmt::Display for Duration {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let nanos = self.as_nanos();
match () {
_ if nanos % 86_400_000_000_000 == 0 => write!(f, "{}d", nanos / 86_400_000_000_000),
_ if nanos % 3_600_000_000_000 == 0 => write!(f, "{}h", nanos / 3_600_000_000_000),
_ if nanos % 60_000_000_000 == 0 => write!(f, "{}m", nanos / 60_000_000_000),
_ if nanos % 1_000_000_000 == 0 => write!(f, "{}s", nanos / 1_000_000_000),
_ if nanos % 1_000_000 == 0 => write!(f, "{}ms", nanos / 1_000_000),
_ if nanos % 1_000 == 0 => write!(f, "{}µs", nanos / 1_000),
_ => write!(f, "{}ns", nanos),
}
}
}
impl<'de> Deserialize<'de> for Duration {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserialize_from_str(deserializer)
}
}
impl Serialize for Duration {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serialize_display(self, serializer)
}
}
pub fn deserialize_number_permissive<
'de,
D: serde::de::Deserializer<'de>,
T: FromStr<Err = E> + num::cast::FromPrimitive,
E: std::fmt::Display,
>(
deserializer: D,
) -> std::result::Result<T, D::Error> {
struct Visitor<T: FromStr<Err = E> + num::cast::FromPrimitive, E>(std::marker::PhantomData<T>);
impl<'de, T: FromStr<Err = Err> + num::cast::FromPrimitive, Err: std::fmt::Display>
serde::de::Visitor<'de> for Visitor<T, Err>
{
type Value = T;
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "a parsable string")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
v.parse().map_err(|e| serde::de::Error::custom(e))
}
fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
T::from_f64(v).ok_or_else(|| {
serde::de::Error::custom(format!(
"{} cannot be represented by the requested type",
v
))
})
}
fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
T::from_u64(v).ok_or_else(|| {
serde::de::Error::custom(format!(
"{} cannot be represented by the requested type",
v
))
})
}
fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
T::from_i64(v).ok_or_else(|| {
serde::de::Error::custom(format!(
"{} cannot be represented by the requested type",
v
))
})
}
}
deserializer.deserialize_str(Visitor(std::marker::PhantomData))
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Port(pub u16);
impl<'de> Deserialize<'de> for Port {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
//TODO: if number, be permissive
deserialize_number_permissive(deserializer).map(Port)
}
}
impl Serialize for Port {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serialize_display(&self.0, serializer)
}
}
#[derive(Debug, Clone)]
pub struct Reversible<T, Container = Vec<T>>
where
for<'a> &'a Container: IntoDoubleEndedIterator<&'a T>,
{
reversed: bool,
data: Container,
phantom: PhantomData<T>,
}
impl<T, Container> Reversible<T, Container>
where
for<'a> &'a Container: IntoDoubleEndedIterator<&'a T>,
{
pub fn new(data: Container) -> Self {
Reversible {
reversed: false,
data,
phantom: PhantomData,
}
}
pub fn reverse(&mut self) {
self.reversed = !self.reversed
}
pub fn iter(
&self,
) -> itertools::Either<
<&Container as IntoDoubleEndedIterator<&T>>::IntoIter,
std::iter::Rev<<&Container as IntoDoubleEndedIterator<&T>>::IntoIter>,
> {
let iter = IntoDoubleEndedIterator::into_iter(&self.data);
if self.reversed {
itertools::Either::Right(iter.rev())
} else {
itertools::Either::Left(iter)
}
}
}
impl<T, Container> Serialize for Reversible<T, Container>
where
for<'a> &'a Container: IntoDoubleEndedIterator<&'a T>,
T: Serialize,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let iter = IntoDoubleEndedIterator::into_iter(&self.data);
let mut seq_ser = serializer.serialize_seq(match iter.size_hint() {
(lower, Some(upper)) if lower == upper => Some(upper),
_ => None,
})?;
if self.reversed {
for elem in iter.rev() {
seq_ser.serialize_element(elem)?;
}
} else {
for elem in iter {
seq_ser.serialize_element(elem)?;
}
}
seq_ser.end()
}
}
impl<'de, T, Container> Deserialize<'de> for Reversible<T, Container>
where
for<'a> &'a Container: IntoDoubleEndedIterator<&'a T>,
Container: Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
Ok(Reversible::new(Deserialize::deserialize(deserializer)?))
}
fn deserialize_in_place<D>(deserializer: D, place: &mut Self) -> Result<(), D::Error>
where
D: Deserializer<'de>,
{
Deserialize::deserialize_in_place(deserializer, &mut place.data)
}
}
pub struct KeyVal<K, V> {
pub key: K,
pub value: V,
}
impl<K: Serialize, V: Serialize> Serialize for KeyVal<K, V> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut map = serializer.serialize_map(Some(1))?;
map.serialize_entry(&self.key, &self.value)?;
map.end()
}
}
impl<'de, K: Deserialize<'de>, V: Deserialize<'de>> Deserialize<'de> for KeyVal<K, V> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor<K, V>(PhantomData<(K, V)>);
impl<'de, K: Deserialize<'de>, V: Deserialize<'de>> serde::de::Visitor<'de> for Visitor<K, V> {
type Value = KeyVal<K, V>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "A map with a single element")
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
let (key, value) = map
.next_entry()?
.ok_or_else(|| serde::de::Error::invalid_length(0, &"1"))?;
Ok(KeyVal { key, value })
}
}
deserializer.deserialize_map(Visitor(PhantomData))
}
}
pub struct Base32<T>(pub T);
impl<'de, T: TryFrom<Vec<u8>>> Deserialize<'de> for Base32<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
base32::decode(base32::Alphabet::RFC4648 { padding: true }, &s)
.ok_or_else(|| {
serde::de::Error::invalid_value(
serde::de::Unexpected::Str(&s),
&"a valid base32 string",
)
})?
.try_into()
.map_err(|_| serde::de::Error::custom("invalid length"))
.map(Self)
}
}
impl<T: AsRef<[u8]>> Serialize for Base32<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&base32::encode(
base32::Alphabet::RFC4648 { padding: true },
self.0.as_ref(),
))
}
}
pub struct Base64<T>(pub T);
impl<'de, T: TryFrom<Vec<u8>>> Deserialize<'de> for Base64<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
base64::decode(&s)
.map_err(serde::de::Error::custom)?
.try_into()
.map_err(|_| serde::de::Error::custom("invalid length"))
.map(Self)
}
}
impl<T: AsRef<[u8]>> Serialize for Base64<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&base64::encode(self.0.as_ref()))
}
}
#[derive(Clone, Debug)]
pub struct Regex(regex::Regex);
impl From<Regex> for regex::Regex {
fn from(value: Regex) -> Self {
value.0
}
}
impl From<regex::Regex> for Regex {
fn from(value: regex::Regex) -> Self {
Regex(value)
}
}
impl AsRef<regex::Regex> for Regex {
fn as_ref(&self) -> &regex::Regex {
&self.0
}
}
impl AsMut<regex::Regex> for Regex {
fn as_mut(&mut self) -> &mut regex::Regex {
&mut self.0
}
}
impl<'de> Deserialize<'de> for Regex {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserialize_from_str(deserializer).map(Self)
}
}
impl Serialize for Regex {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serialize_display(&self.0, serializer)
}
}