mirror of
https://github.com/Start9Labs/start-os.git
synced 2026-03-30 12:11:56 +00:00
fix: Cleanup by sending a command and kill when dropped (#1945)
* fix: Cleanup by sending a command and kill when dropped * chore: Fix the loadModule run command * fix: cleans up failed health * refactor long-running * chore: Fixes?" * refactor * run iso ci on pr * fix debuild * fix tests * switch to libc kill * kill process by parent * fix graceful shutdown * recurse submodules * fix compat build * feat: Add back in the timeout * chore: add the missing types for the unnstable * inherited logs Co-authored-by: J M <Blu-J@users.noreply.github.com> * fix deleted code Co-authored-by: Aiden McClelland <me@drbonez.dev> Co-authored-by: J M <Blu-J@users.noreply.github.com>
This commit is contained in:
@@ -1,104 +1,173 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::instrument;
|
||||
use nix::unistd::Pid;
|
||||
use serde::{Deserialize, Serialize, Serializer};
|
||||
use yajrc::RpcMethod;
|
||||
|
||||
/// The inputs that the executable is expecting
|
||||
pub type InputJsonRpc = JsonRpc<Input>;
|
||||
/// The outputs that the executable is expected to output
|
||||
pub type OutputJsonRpc = JsonRpc<Output>;
|
||||
|
||||
/// Based on the jsonrpc spec, but we are limiting the rpc to a subset
|
||||
#[derive(Debug, Serialize, Copy, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[serde(untagged)]
|
||||
pub enum RpcId {
|
||||
UInt(u32),
|
||||
}
|
||||
/// Know what the process is called
|
||||
#[derive(Debug, Serialize, Deserialize, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[derive(Debug, Serialize, Deserialize, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub struct ProcessId(pub u32);
|
||||
|
||||
/// We use the JSON rpc as the format to share between the stdin and stdout for the executable.
|
||||
/// Note: We are not allowing the id to not exist, used to ensure all pairs of messages are tracked
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
|
||||
pub struct JsonRpc<T> {
|
||||
id: RpcId,
|
||||
#[serde(flatten)]
|
||||
pub version_rpc: VersionRpc<T>,
|
||||
impl From<ProcessId> for Pid {
|
||||
fn from(pid: ProcessId) -> Self {
|
||||
Pid::from_raw(pid.0 as i32)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
|
||||
#[serde(tag = "jsonrpc", rename_all = "camelCase")]
|
||||
pub enum VersionRpc<T> {
|
||||
#[serde(rename = "2.0")]
|
||||
Two(T),
|
||||
impl From<Pid> for ProcessId {
|
||||
fn from(pid: Pid) -> Self {
|
||||
ProcessId(pid.as_raw() as u32)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> JsonRpc<T>
|
||||
where
|
||||
T: Serialize + for<'de> serde::Deserialize<'de> + std::fmt::Debug,
|
||||
{
|
||||
/// Using this to simplify creating this nested struct. Used for creating input mostly for executable stdin
|
||||
pub fn new(id: RpcId, body: T) -> Self {
|
||||
JsonRpc {
|
||||
id,
|
||||
version_rpc: VersionRpc::Two(body),
|
||||
}
|
||||
}
|
||||
/// Use this to get the data out of the probably destructed output
|
||||
pub fn into_pair(self) -> (RpcId, T) {
|
||||
let Self { id, version_rpc } = self;
|
||||
let VersionRpc::Two(body) = version_rpc;
|
||||
(id, body)
|
||||
}
|
||||
/// Used during the execution.
|
||||
#[instrument]
|
||||
pub fn maybe_serialize(&self) -> Option<String> {
|
||||
match serde_json::to_string(self) {
|
||||
Ok(x) => Some(x),
|
||||
Err(e) => {
|
||||
tracing::warn!("Could not stringify and skipping");
|
||||
tracing::debug!("{:?}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Used during the execution
|
||||
#[instrument]
|
||||
pub fn maybe_parse(s: &str) -> Option<Self> {
|
||||
match serde_json::from_str::<Self>(s) {
|
||||
Ok(a) => Some(a),
|
||||
Err(e) => {
|
||||
tracing::warn!("Could not parse and skipping: {}", s);
|
||||
tracing::debug!("{:?}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
impl From<i32> for ProcessId {
|
||||
fn from(pid: i32) -> Self {
|
||||
ProcessId(pid as u32)
|
||||
}
|
||||
}
|
||||
|
||||
/// Outputs embedded in the JSONRpc output of the executable.
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
#[serde(tag = "method", content = "params", rename_all = "camelCase")]
|
||||
pub enum Output {
|
||||
/// Will be called almost right away and only once per RpcId. Indicates what
|
||||
/// was spawned in the container.
|
||||
ProcessId(ProcessId),
|
||||
/// This is the line buffered output of the command
|
||||
Line(String),
|
||||
/// This is some kind of error with the program
|
||||
Error(String),
|
||||
/// Indication that the command is done
|
||||
Done(Option<i32>),
|
||||
#[derive(Debug, Serialize, Deserialize, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub struct ProcessGroupId(pub u32);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum OutputStrategy {
|
||||
Inherit,
|
||||
Collect,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
#[serde(tag = "method", content = "params", rename_all = "camelCase")]
|
||||
pub enum Input {
|
||||
/// Create a new command, with the args
|
||||
Command { command: String, args: Vec<String> },
|
||||
/// Send the sigkill to the process
|
||||
Kill(),
|
||||
/// Send the sigterm to the process
|
||||
Term(),
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct RunCommand;
|
||||
impl Serialize for RunCommand {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
Serialize::serialize(Self.as_str(), serializer)
|
||||
}
|
||||
}
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RunCommandParams {
|
||||
pub gid: Option<ProcessGroupId>,
|
||||
pub command: String,
|
||||
pub args: Vec<String>,
|
||||
pub output: OutputStrategy,
|
||||
}
|
||||
impl RpcMethod for RunCommand {
|
||||
type Params = RunCommandParams;
|
||||
type Response = ProcessId;
|
||||
fn as_str<'a>(&'a self) -> &'a str {
|
||||
"command"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ReadLineStdout;
|
||||
impl Serialize for ReadLineStdout {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
Serialize::serialize(Self.as_str(), serializer)
|
||||
}
|
||||
}
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReadLineStdoutParams {
|
||||
pub pid: ProcessId,
|
||||
}
|
||||
impl RpcMethod for ReadLineStdout {
|
||||
type Params = ReadLineStdoutParams;
|
||||
type Response = String;
|
||||
fn as_str<'a>(&'a self) -> &'a str {
|
||||
"read-line-stdout"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ReadLineStderr;
|
||||
impl Serialize for ReadLineStderr {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
Serialize::serialize(Self.as_str(), serializer)
|
||||
}
|
||||
}
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReadLineStderrParams {
|
||||
pub pid: ProcessId,
|
||||
}
|
||||
impl RpcMethod for ReadLineStderr {
|
||||
type Params = ReadLineStderrParams;
|
||||
type Response = String;
|
||||
fn as_str<'a>(&'a self) -> &'a str {
|
||||
"read-line-stderr"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Output;
|
||||
impl Serialize for Output {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
Serialize::serialize(Self.as_str(), serializer)
|
||||
}
|
||||
}
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OutputParams {
|
||||
pub pid: ProcessId,
|
||||
}
|
||||
impl RpcMethod for Output {
|
||||
type Params = OutputParams;
|
||||
type Response = String;
|
||||
fn as_str<'a>(&'a self) -> &'a str {
|
||||
"output"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct SendSignal;
|
||||
impl Serialize for SendSignal {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
Serialize::serialize(Self.as_str(), serializer)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SendSignalParams {
|
||||
pub pid: ProcessId,
|
||||
pub signal: u32,
|
||||
}
|
||||
impl RpcMethod for SendSignal {
|
||||
type Params = SendSignalParams;
|
||||
type Response = ();
|
||||
fn as_str<'a>(&'a self) -> &'a str {
|
||||
"signal"
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct SignalGroup;
|
||||
impl Serialize for SignalGroup {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
Serialize::serialize(Self.as_str(), serializer)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SignalGroupParams {
|
||||
pub gid: ProcessGroupId,
|
||||
pub signal: u32,
|
||||
}
|
||||
impl RpcMethod for SignalGroup {
|
||||
type Params = SignalGroupParams;
|
||||
type Response = ();
|
||||
fn as_str<'a>(&'a self) -> &'a str {
|
||||
"signal-group"
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -1,299 +1,388 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::ops::DerefMut;
|
||||
use std::os::unix::process::ExitStatusExt;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_stream::stream;
|
||||
use embassy_container_init::{
|
||||
Input, InputJsonRpc, JsonRpc, Output, OutputJsonRpc, ProcessId, RpcId,
|
||||
OutputParams, OutputStrategy, ProcessGroupId, ProcessId, ReadLineStderrParams,
|
||||
ReadLineStdoutParams, RunCommandParams, SendSignalParams, SignalGroupParams,
|
||||
};
|
||||
use futures::{pin_mut, Stream, StreamExt};
|
||||
use futures::StreamExt;
|
||||
use helpers::NonDetachingJoinHandle;
|
||||
use nix::errno::Errno;
|
||||
use nix::sys::signal::Signal;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::{Child, Command};
|
||||
use tokio::process::{Child, ChildStderr, ChildStdout, Command};
|
||||
use tokio::select;
|
||||
use tokio::sync::{oneshot, Mutex};
|
||||
use tracing::instrument;
|
||||
use tokio::sync::{watch, Mutex};
|
||||
use yajrc::{Id, RpcError};
|
||||
|
||||
const MAX_COMMANDS: usize = 10;
|
||||
|
||||
enum DoneProgramStatus {
|
||||
Wait(Result<std::process::ExitStatus, std::io::Error>),
|
||||
Killed,
|
||||
}
|
||||
/// Created from the child and rpc, to prove that the cmd was the one who died
|
||||
struct DoneProgram {
|
||||
id: RpcId,
|
||||
status: DoneProgramStatus,
|
||||
/// Outputs embedded in the JSONRpc output of the executable.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(untagged)]
|
||||
enum Output {
|
||||
Command(ProcessId),
|
||||
ReadLineStdout(String),
|
||||
ReadLineStderr(String),
|
||||
Output(String),
|
||||
Signal,
|
||||
SignalGroup,
|
||||
}
|
||||
|
||||
/// Used to attach the running command with the rpc
|
||||
struct ChildAndRpc {
|
||||
id: RpcId,
|
||||
child: Child,
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "method", content = "params", rename_all = "kebab-case")]
|
||||
enum Input {
|
||||
/// Run a new command, with the args
|
||||
Command(RunCommandParams),
|
||||
// /// Get a line of stdout from the command
|
||||
// ReadLineStdout(ReadLineStdoutParams),
|
||||
// /// Get a line of stderr from the command
|
||||
// ReadLineStderr(ReadLineStderrParams),
|
||||
/// Get output of command
|
||||
Output(OutputParams),
|
||||
/// Send the sigterm to the process
|
||||
Signal(SendSignalParams),
|
||||
/// Signal a group of processes
|
||||
SignalGroup(SignalGroupParams),
|
||||
}
|
||||
|
||||
impl ChildAndRpc {
|
||||
fn new(id: RpcId, mut command: tokio::process::Command) -> ::std::io::Result<Self> {
|
||||
Ok(Self {
|
||||
id,
|
||||
child: command.spawn()?,
|
||||
#[derive(Deserialize)]
|
||||
struct IncomingRpc {
|
||||
id: Id,
|
||||
#[serde(flatten)]
|
||||
input: Input,
|
||||
}
|
||||
|
||||
struct ChildInfo {
|
||||
gid: Option<ProcessGroupId>,
|
||||
child: Arc<Mutex<Option<Child>>>,
|
||||
output: Option<InheritOutput>,
|
||||
}
|
||||
|
||||
struct InheritOutput {
|
||||
_thread: NonDetachingJoinHandle<()>,
|
||||
stdout: watch::Receiver<String>,
|
||||
stderr: watch::Receiver<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Handler {
|
||||
children: Arc<Mutex<BTreeMap<ProcessId, ChildInfo>>>,
|
||||
}
|
||||
impl Handler {
|
||||
fn new() -> Self {
|
||||
Handler {
|
||||
children: Arc::new(Mutex::new(BTreeMap::new())),
|
||||
}
|
||||
}
|
||||
async fn handle(&self, req: Input) -> Result<Output, RpcError> {
|
||||
Ok(match req {
|
||||
Input::Command(RunCommandParams {
|
||||
gid,
|
||||
command,
|
||||
args,
|
||||
output,
|
||||
}) => Output::Command(self.command(gid, command, args, output).await?),
|
||||
// Input::ReadLineStdout(ReadLineStdoutParams { pid }) => {
|
||||
// Output::ReadLineStdout(self.read_line_stdout(pid).await?)
|
||||
// }
|
||||
// Input::ReadLineStderr(ReadLineStderrParams { pid }) => {
|
||||
// Output::ReadLineStderr(self.read_line_stderr(pid).await?)
|
||||
// }
|
||||
Input::Output(OutputParams { pid }) => Output::Output(self.output(pid).await?),
|
||||
Input::Signal(SendSignalParams { pid, signal }) => {
|
||||
self.signal(pid, signal).await?;
|
||||
Output::Signal
|
||||
}
|
||||
Input::SignalGroup(SignalGroupParams { gid, signal }) => {
|
||||
self.signal_group(gid, signal).await?;
|
||||
Output::SignalGroup
|
||||
}
|
||||
})
|
||||
}
|
||||
fn id(&self) -> Option<ProcessId> {
|
||||
self.child.id().map(ProcessId)
|
||||
}
|
||||
async fn wait(&mut self) -> DoneProgram {
|
||||
let status = DoneProgramStatus::Wait(self.child.wait().await);
|
||||
DoneProgram {
|
||||
id: self.id.clone(),
|
||||
status,
|
||||
}
|
||||
}
|
||||
async fn kill(mut self) -> DoneProgram {
|
||||
if let Err(err) = self.child.kill().await {
|
||||
let id = &self.id;
|
||||
tracing::error!("Error while trying to kill a process {id:?}");
|
||||
tracing::debug!("{err:?}");
|
||||
}
|
||||
DoneProgram {
|
||||
id: self.id.clone(),
|
||||
status: DoneProgramStatus::Killed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Controlls the tracing + other io events
|
||||
/// Can get the inputs from stdin
|
||||
/// Can start a command from an intputrpc returning stream of outputs
|
||||
/// Can output to stdout
|
||||
#[derive(Debug, Clone)]
|
||||
struct Io {
|
||||
commands: Arc<Mutex<BTreeMap<RpcId, oneshot::Sender<()>>>>,
|
||||
ids: Arc<Mutex<BTreeMap<RpcId, ProcessId>>>,
|
||||
}
|
||||
|
||||
impl Io {
|
||||
fn start() -> Self {
|
||||
use tracing_error::ErrorLayer;
|
||||
use tracing_subscriber::prelude::*;
|
||||
use tracing_subscriber::{fmt, EnvFilter};
|
||||
|
||||
let filter_layer = EnvFilter::new("embassy_container_init=debug");
|
||||
let fmt_layer = fmt::layer().with_target(true);
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(filter_layer)
|
||||
.with(fmt_layer)
|
||||
.with(ErrorLayer::default())
|
||||
.init();
|
||||
color_eyre::install().unwrap();
|
||||
Self {
|
||||
commands: Default::default(),
|
||||
ids: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument]
|
||||
fn command(&self, input: InputJsonRpc) -> impl Stream<Item = OutputJsonRpc> {
|
||||
let io = self.clone();
|
||||
stream! {
|
||||
let (id, command) = input.into_pair();
|
||||
match command {
|
||||
Input::Command {
|
||||
ref command,
|
||||
ref args,
|
||||
} => {
|
||||
let mut cmd = Command::new(command);
|
||||
cmd.args(args);
|
||||
|
||||
cmd.stdout(Stdio::piped());
|
||||
cmd.stderr(Stdio::piped());
|
||||
let mut child_and_rpc = match ChildAndRpc::new(id.clone(), cmd) {
|
||||
Err(_e) => return,
|
||||
Ok(a) => a,
|
||||
};
|
||||
|
||||
if let Some(child_id) = child_and_rpc.id() {
|
||||
io.ids.lock().await.insert(id.clone(), child_id.clone());
|
||||
yield JsonRpc::new(id.clone(), Output::ProcessId(child_id));
|
||||
}
|
||||
|
||||
let stdout = child_and_rpc.child
|
||||
.stdout
|
||||
.take()
|
||||
.expect("child did not have a handle to stdout");
|
||||
let stderr = child_and_rpc.child
|
||||
.stderr
|
||||
.take()
|
||||
.expect("child did not have a handle to stderr");
|
||||
|
||||
let mut buff_out = BufReader::new(stdout).lines();
|
||||
let mut buff_err = BufReader::new(stderr).lines();
|
||||
|
||||
let spawned = tokio::spawn({
|
||||
let id = id.clone();
|
||||
async move {
|
||||
let end_command_receiver = io.create_end_command(id.clone()).await;
|
||||
tokio::select!{
|
||||
waited = child_and_rpc
|
||||
.wait() => {
|
||||
io.clean_id(&waited).await;
|
||||
match &waited.status {
|
||||
DoneProgramStatus::Wait(Ok(st)) => return st.code(),
|
||||
DoneProgramStatus::Wait(Err(err)) => tracing::debug!("Child {id:?} got error: {err:?}"),
|
||||
DoneProgramStatus::Killed => tracing::debug!("Child {id:?} already killed?"),
|
||||
async fn command(
|
||||
&self,
|
||||
gid: Option<ProcessGroupId>,
|
||||
command: String,
|
||||
args: Vec<String>,
|
||||
output: OutputStrategy,
|
||||
) -> Result<ProcessId, RpcError> {
|
||||
let mut cmd = Command::new(command);
|
||||
cmd.args(args);
|
||||
cmd.kill_on_drop(true);
|
||||
cmd.stdout(Stdio::piped());
|
||||
cmd.stderr(Stdio::piped());
|
||||
let mut child = cmd.spawn().map_err(|e| {
|
||||
let mut err = yajrc::INTERNAL_ERROR.clone();
|
||||
err.data = Some(json!(e.to_string()));
|
||||
err
|
||||
})?;
|
||||
let pid = ProcessId(child.id().ok_or_else(|| {
|
||||
let mut err = yajrc::INTERNAL_ERROR.clone();
|
||||
err.data = Some(json!("Child has no pid"));
|
||||
err
|
||||
})?);
|
||||
let output = match output {
|
||||
OutputStrategy::Inherit => {
|
||||
let (stdout_send, stdout) = watch::channel(String::new());
|
||||
let (stderr_send, stderr) = watch::channel(String::new());
|
||||
if let (Some(child_stdout), Some(child_stderr)) =
|
||||
(child.stdout.take(), child.stderr.take())
|
||||
{
|
||||
Some(InheritOutput {
|
||||
_thread: tokio::spawn(async move {
|
||||
tokio::join!(
|
||||
async {
|
||||
if let Err(e) = async {
|
||||
let mut lines = BufReader::new(child_stdout).lines();
|
||||
while let Some(line) = lines.next_line().await? {
|
||||
tracing::info!("({}): {}", pid.0, line);
|
||||
let _ = stdout_send.send(line);
|
||||
}
|
||||
|
||||
},
|
||||
_ = end_command_receiver => {
|
||||
let status = child_and_rpc.kill().await;
|
||||
io.clean_id(&status).await;
|
||||
Ok::<_, std::io::Error>(())
|
||||
}
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
"Error reading stdout of pid {}: {}",
|
||||
pid.0,
|
||||
e
|
||||
);
|
||||
}
|
||||
},
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
});
|
||||
while let Ok(Some(line)) = buff_out.next_line().await {
|
||||
let output = Output::Line(line);
|
||||
let output = JsonRpc::new(id.clone(), output);
|
||||
tracing::trace!("OutputJsonRpc {{ id, output_rpc }} = {:?}", output);
|
||||
yield output;
|
||||
}
|
||||
while let Ok(Some(line)) = buff_err.next_line().await {
|
||||
yield JsonRpc::new(id.clone(), Output::Error(line));
|
||||
}
|
||||
let code = spawned.await.ok().flatten();
|
||||
yield JsonRpc::new(id, Output::Done(code));
|
||||
},
|
||||
Input::Kill() => {
|
||||
io.trigger_end_command(id).await;
|
||||
async {
|
||||
if let Err(e) = async {
|
||||
let mut lines = BufReader::new(child_stderr).lines();
|
||||
while let Some(line) = lines.next_line().await? {
|
||||
tracing::warn!("({}): {}", pid.0, line);
|
||||
let _ = stderr_send.send(line);
|
||||
}
|
||||
Ok::<_, std::io::Error>(())
|
||||
}
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
"Error reading stdout of pid {}: {}",
|
||||
pid.0,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
);
|
||||
})
|
||||
.into(),
|
||||
stdout,
|
||||
stderr,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
Input::Term() => {
|
||||
io.term_by_rpc(&id).await;
|
||||
}
|
||||
OutputStrategy::Collect => None,
|
||||
};
|
||||
self.children.lock().await.insert(
|
||||
pid,
|
||||
ChildInfo {
|
||||
gid,
|
||||
child: Arc::new(Mutex::new(Some(child))),
|
||||
output,
|
||||
},
|
||||
);
|
||||
Ok(pid)
|
||||
}
|
||||
|
||||
async fn output(&self, pid: ProcessId) -> Result<String, RpcError> {
|
||||
let not_found = || {
|
||||
let mut err = yajrc::INTERNAL_ERROR.clone();
|
||||
err.data = Some(json!(format!("Child with pid {} not found", pid.0)));
|
||||
err
|
||||
};
|
||||
let mut child = {
|
||||
self.children
|
||||
.lock()
|
||||
.await
|
||||
.get(&pid)
|
||||
.ok_or_else(not_found)?
|
||||
.child
|
||||
.clone()
|
||||
}
|
||||
.lock_owned()
|
||||
.await;
|
||||
if let Some(child) = child.take() {
|
||||
let output = child.wait_with_output().await?;
|
||||
if output.status.success() {
|
||||
Ok(String::from_utf8(output.stdout).map_err(|_| yajrc::PARSE_ERROR)?)
|
||||
} else {
|
||||
Err(RpcError {
|
||||
code: output
|
||||
.status
|
||||
.code()
|
||||
.or_else(|| output.status.signal().map(|s| 128 + s))
|
||||
.unwrap_or(0),
|
||||
message: "Command failed".into(),
|
||||
data: Some(json!(String::from_utf8(if output.stderr.is_empty() {
|
||||
output.stdout
|
||||
} else {
|
||||
output.stderr
|
||||
})
|
||||
.map_err(|_| yajrc::PARSE_ERROR)?)),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
Err(not_found())
|
||||
}
|
||||
}
|
||||
|
||||
async fn signal(&self, pid: ProcessId, signal: u32) -> Result<(), RpcError> {
|
||||
let not_found = || {
|
||||
let mut err = yajrc::INTERNAL_ERROR.clone();
|
||||
err.data = Some(json!(format!("Child with pid {} not found", pid.0)));
|
||||
err
|
||||
};
|
||||
|
||||
Self::killall(pid, Signal::try_from(signal as i32)?)?;
|
||||
|
||||
if signal == 9 {
|
||||
self.children
|
||||
.lock()
|
||||
.await
|
||||
.remove(&pid)
|
||||
.ok_or_else(not_found)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn signal_group(&self, gid: ProcessGroupId, signal: u32) -> Result<(), RpcError> {
|
||||
let mut to_kill = Vec::new();
|
||||
{
|
||||
let mut children_ref = self.children.lock().await;
|
||||
let children = std::mem::take(children_ref.deref_mut());
|
||||
for (pid, child_info) in children {
|
||||
if child_info.gid == Some(gid) {
|
||||
to_kill.push(pid);
|
||||
} else {
|
||||
children_ref.insert(pid, child_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
for pid in to_kill {
|
||||
tracing::info!("Killing pid {}", pid.0);
|
||||
Self::killall(pid, Signal::try_from(signal as i32)?)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
/// Used to get the string lines from the stdin
|
||||
fn inputs(&self) -> impl Stream<Item = String> {
|
||||
use std::io::BufRead;
|
||||
let (sender, receiver) = tokio::sync::mpsc::channel(100);
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let stdin = std::io::stdin();
|
||||
for line in stdin.lock().lines().flatten() {
|
||||
tracing::trace!("Line = {}", line);
|
||||
sender.blocking_send(line).unwrap();
|
||||
|
||||
fn killall(pid: ProcessId, signal: Signal) -> Result<(), RpcError> {
|
||||
for proc in procfs::process::all_processes()? {
|
||||
let stat = proc?.stat()?;
|
||||
if ProcessId::from(stat.ppid) == pid {
|
||||
Self::killall(stat.pid.into(), signal)?;
|
||||
}
|
||||
}
|
||||
if let Err(e) = nix::sys::signal::kill(pid.into(), Some(signal)) {
|
||||
if e != Errno::ESRCH {
|
||||
tracing::error!("Failed to kill pid {}: {}", pid.0, e);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn graceful_exit(self) {
|
||||
let kill_all = futures::stream::iter(
|
||||
std::mem::take(self.children.lock().await.deref_mut()).into_iter(),
|
||||
)
|
||||
.for_each_concurrent(None, |(pid, child)| async move {
|
||||
let _ = Self::killall(pid, Signal::SIGTERM);
|
||||
if let Some(child) = child.child.lock().await.take() {
|
||||
let _ = child.wait_with_output().await;
|
||||
}
|
||||
});
|
||||
tokio_stream::wrappers::ReceiverStream::new(receiver)
|
||||
}
|
||||
|
||||
///Convert a stream of string to stdout
|
||||
async fn output(&self, outputs: impl Stream<Item = String>) {
|
||||
pin_mut!(outputs);
|
||||
while let Some(output) = outputs.next().await {
|
||||
println!("{}", output);
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for the command fn
|
||||
/// Part of a pair for the signal map, that indicates that we should kill the command
|
||||
async fn trigger_end_command(&self, id: RpcId) {
|
||||
if let Some(command) = self.commands.lock().await.remove(&id) {
|
||||
if command.send(()).is_err() {
|
||||
tracing::trace!("Command {id:?} could not be ended, possible error or was done");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for the command fn
|
||||
/// Part of a pair for the signal map, that indicates that we should kill the command
|
||||
async fn create_end_command(&self, id: RpcId) -> oneshot::Receiver<()> {
|
||||
let (send, receiver) = oneshot::channel();
|
||||
if let Some(other_command) = self.commands.lock().await.insert(id.clone(), send) {
|
||||
if other_command.send(()).is_err() {
|
||||
tracing::trace!(
|
||||
"Found other command {id:?} could not be ended, possible error or was done"
|
||||
);
|
||||
}
|
||||
}
|
||||
receiver
|
||||
}
|
||||
|
||||
/// Used during cleaning up a procress
|
||||
async fn clean_id(
|
||||
&self,
|
||||
done_program: &DoneProgram,
|
||||
) -> (Option<ProcessId>, Option<oneshot::Sender<()>>) {
|
||||
(
|
||||
self.ids.lock().await.remove(&done_program.id),
|
||||
self.commands.lock().await.remove(&done_program.id),
|
||||
)
|
||||
}
|
||||
|
||||
/// Given the rpcid, will try and term the running command
|
||||
async fn term_by_rpc(&self, rpc: &RpcId) {
|
||||
let output = match self.remove_cmd_id(rpc).await {
|
||||
Some(id) => {
|
||||
let mut cmd = tokio::process::Command::new("kill");
|
||||
cmd.arg(format!("{}", id.0));
|
||||
cmd.output().await
|
||||
}
|
||||
None => return,
|
||||
};
|
||||
match output {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
tracing::error!("Could not kill rpc {rpc:?}");
|
||||
tracing::debug!("{err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Used as a cleanup
|
||||
async fn term_all(self) {
|
||||
let ids: Vec<_> = self.ids.lock().await.keys().cloned().collect();
|
||||
for id in ids {
|
||||
self.term_by_rpc(&id).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn remove_cmd_id(&self, rpc: &RpcId) -> Option<ProcessId> {
|
||||
self.ids.lock().await.remove(rpc)
|
||||
kill_all.await
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
use futures::StreamExt;
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
let mut sigint = signal(SignalKind::interrupt()).unwrap();
|
||||
let mut sigterm = signal(SignalKind::terminate()).unwrap();
|
||||
let mut sigquit = signal(SignalKind::quit()).unwrap();
|
||||
let mut sighangup = signal(SignalKind::hangup()).unwrap();
|
||||
let io = Io::start();
|
||||
let outputs = io
|
||||
.inputs()
|
||||
.filter_map(|x| async move { InputJsonRpc::maybe_parse(&x) })
|
||||
.flat_map_unordered(MAX_COMMANDS, |x| io.command(x).boxed())
|
||||
.filter_map(|x| async move { x.maybe_serialize() });
|
||||
|
||||
use tracing_error::ErrorLayer;
|
||||
use tracing_subscriber::prelude::*;
|
||||
use tracing_subscriber::{fmt, EnvFilter};
|
||||
|
||||
let filter_layer = EnvFilter::new("embassy_container_init=debug");
|
||||
let fmt_layer = fmt::layer().with_target(true);
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(filter_layer)
|
||||
.with(fmt_layer)
|
||||
.with(ErrorLayer::default())
|
||||
.init();
|
||||
color_eyre::install().unwrap();
|
||||
|
||||
let handler = Handler::new();
|
||||
let mut lines = BufReader::new(tokio::io::stdin()).lines();
|
||||
let handler_thread = async {
|
||||
while let Some(line) = lines.next_line().await? {
|
||||
let local_hdlr = handler.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = async {
|
||||
eprintln!("{}", line);
|
||||
let req = serde_json::from_str::<IncomingRpc>(&line)?;
|
||||
match local_hdlr.handle(req.input).await {
|
||||
Ok(output) => {
|
||||
println!(
|
||||
"{}",
|
||||
json!({ "id": req.id, "jsonrpc": "2.0", "result": output })
|
||||
)
|
||||
}
|
||||
Err(e) => {
|
||||
println!("{}", json!({ "id": req.id, "jsonrpc": "2.0", "error": e }))
|
||||
}
|
||||
}
|
||||
Ok::<_, serde_json::Error>(())
|
||||
}
|
||||
.await
|
||||
{
|
||||
tracing::error!("Error parsing RPC request: {}", e);
|
||||
tracing::debug!("{:?}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
Ok::<_, std::io::Error>(())
|
||||
};
|
||||
|
||||
select! {
|
||||
_ = io.output(outputs) => {
|
||||
tracing::debug!("Done with inputs/outputs")
|
||||
res = handler_thread => {
|
||||
match res {
|
||||
Ok(()) => tracing::debug!("Done with inputs/outputs"),
|
||||
Err(e) => {
|
||||
tracing::error!("Error reading RPC input: {}", e);
|
||||
tracing::debug!("{:?}", e);
|
||||
}
|
||||
}
|
||||
},
|
||||
_ = sigint.recv() => {
|
||||
tracing::debug!("Sigint")
|
||||
tracing::debug!("SIGINT");
|
||||
},
|
||||
_ = sigterm.recv() => {
|
||||
tracing::debug!("Sig Term")
|
||||
tracing::debug!("SIGTERM");
|
||||
},
|
||||
_ = sigquit.recv() => {
|
||||
tracing::debug!("Sigquit")
|
||||
tracing::debug!("SIGQUIT");
|
||||
},
|
||||
_ = sighangup.recv() => {
|
||||
tracing::debug!("Sighangup")
|
||||
tracing::debug!("SIGHUP");
|
||||
}
|
||||
}
|
||||
io.term_all().await;
|
||||
::std::process::exit(0);
|
||||
handler.graceful_exit().await;
|
||||
::std::process::exit(0)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user