diff --git a/Cargo.lock b/Cargo.lock index 208f7ae..cafdfe2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -103,7 +103,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -114,7 +114,7 @@ checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -196,6 +196,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfaff671f6b22ca62406885ece523383b9b64022e341e53e009a62ebc47a45f2" dependencies = [ "clap_builder", + "clap_derive", ] [[package]] @@ -210,6 +211,18 @@ dependencies = [ "strsim", ] +[[package]] +name = "clap_derive" +version = "4.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.41", +] + [[package]] name = "clap_lex" version = "0.6.0" @@ -355,7 +368,7 @@ checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -464,6 +477,12 @@ version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "hermit-abi" version = "0.3.3" @@ -513,6 +532,19 @@ dependencies = [ "http 1.0.0", ] +[[package]] +name = "http-body-util" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840" +dependencies = [ + "bytes", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", + "pin-project-lite", +] + [[package]] name = "httparse" version = "1.8.0" @@ -790,7 +822,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -956,6 +988,8 @@ dependencies = [ "async-trait", "clap", "futures", + "http 1.0.0", + "http-body-util", "hyper 1.0.1", "imbl-value", "lazy_static", @@ -967,6 +1001,7 @@ dependencies = [ "serde_json", "thiserror", "tokio", + "tokio-stream", "url", "yajrc", ] @@ -1079,7 +1114,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -1168,9 +1203,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.40" +version = "2.0.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13fa70a4ee923979ffb522cacce59d34421ebdea5625e1073c4326ef9d2dd42e" +checksum = "44c8b28c477cc3bf0e7966561e3460130e1255f7a1cf71931075f1c5e7a7e269" dependencies = [ "proc-macro2", "quote", @@ -1228,7 +1263,7 @@ checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -1273,7 +1308,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -1286,6 +1321,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.10" @@ -1417,7 +1463,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", "wasm-bindgen-shared", ] @@ -1451,7 +1497,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -1638,9 +1684,9 @@ dependencies = [ [[package]] name = "yajrc" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc08b562507a1674a1ef886a1aedeeb19d41462386ae09f634995d41bbef87d3" +checksum = "47cb33cb21fb6923a0dd074fd20dfd98fc3758103b7e2607db1354b4a86ef37c" dependencies = [ "anyhow", "serde", @@ -1677,5 +1723,5 @@ checksum = "be912bf68235a88fbefd1b73415cb218405958d1655b2ece9035a19920bdf6ba" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] diff --git a/rpc-toolkit/Cargo.toml b/rpc-toolkit/Cargo.toml index aefba8e..cc97b2a 100644 --- a/rpc-toolkit/Cargo.toml +++ b/rpc-toolkit/Cargo.toml @@ -18,8 +18,10 @@ default = ["cbor"] [dependencies] async-stream = "0.3" async-trait = "0.1" -clap = "4" +clap = { version = "4", features = ["derive"] } futures = "0.3" +http = "1" +http-body-util = "0.1" hyper = { version = "1", features = ["server", "http1", "http2", "client"] } imbl-value = "0.1" lazy_static = "1.4" @@ -31,5 +33,6 @@ serde_cbor = { version = "0.11", optional = true } serde_json = "1.0" thiserror = "1.0" tokio = { version = "1", features = ["full"] } +tokio-stream = { version = "0.1", features = ["io-util", "net"] } url = "2" yajrc = "0.1" diff --git a/rpc-toolkit/src/cli.rs b/rpc-toolkit/src/cli.rs index 8afdc6d..a696105 100644 --- a/rpc-toolkit/src/cli.rs +++ b/rpc-toolkit/src/cli.rs @@ -7,12 +7,17 @@ use imbl_value::Value; use reqwest::{Client, Method}; use serde::de::DeserializeOwned; use serde::Serialize; +use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; use url::Url; -use yajrc::{GenericRpcMethod, Id, RpcError, RpcRequest}; +use yajrc::{Id, RpcError}; use crate::command::{AsyncCommand, DynCommand, LeafCommand, ParentInfo}; -use crate::util::{combine, invalid_params, parse_error}; -use crate::{CliBindings, ParentChain}; +use crate::util::{combine, internal_error, invalid_params, parse_error}; +use crate::{CliBindings, SyncCommand}; + +type GenericRpcMethod<'a> = yajrc::GenericRpcMethod<&'a str, Value, Value>; +type RpcRequest<'a> = yajrc::RpcRequest>; +type RpcResponse<'a> = yajrc::RpcResponse>; impl DynCommand { fn cli_app(&self) -> Option { @@ -55,7 +60,7 @@ impl DynCommand { } struct CliApp { - cli: CliBindings, + cli: CliBindings, commands: Vec>, } impl CliApp { @@ -110,6 +115,8 @@ impl CliAppAsync { }), } } +} +impl CliAppAsync { pub async fn run(self, args: Vec) -> Result<(), RpcError> { let cmd = self .app @@ -130,10 +137,10 @@ impl CliAppAsync { .implementation .as_ref() .ok_or(yajrc::METHOD_NOT_FOUND_ERROR)? - .async_impl)(ctx, parent_method, params) + .async_impl)(ctx.clone(), parent_method.clone(), params.clone()) .await?; if let Some(display) = display { - display(res).map_err(parse_error) + display(ctx, parent_method, params, res).map_err(parse_error) } else { Ok(()) } @@ -157,6 +164,8 @@ impl CliAppSync { make_ctx: Box::new(|args| make_ctx(imbl_value::from_value(args).map_err(parse_error)?)), } } +} +impl CliAppSync { pub async fn run(self, args: Vec) -> Result<(), RpcError> { let cmd = self .app @@ -177,9 +186,9 @@ impl CliAppSync { .implementation .as_ref() .ok_or(yajrc::METHOD_NOT_FOUND_ERROR)? - .sync_impl)(ctx, parent_method, params)?; + .sync_impl)(ctx.clone(), parent_method.clone(), params.clone())?; if let Some(display) = display { - display(res).map_err(parse_error) + display(ctx, parent_method, params, res).map_err(parse_error) } else { Ok(()) } @@ -191,14 +200,12 @@ pub trait CliContext: crate::Context { async fn call_remote(&self, method: &str, params: Value) -> Result; } +#[async_trait::async_trait] pub trait CliContextHttp: crate::Context { fn client(&self) -> &Client; fn url(&self) -> Url; -} -#[async_trait::async_trait] -impl CliContext for T { async fn call_remote(&self, method: &str, params: Value) -> Result { - let rpc_req: RpcRequest> = RpcRequest { + let rpc_req = RpcRequest { id: Some(Id::Number(0.into())), method: GenericRpcMethod::new(method), params, @@ -222,33 +229,61 @@ impl CliContext for T { .body(body) .send() .await?; - Ok( - match res - .headers() - .get("content-type") - .and_then(|v| v.to_str().ok()) - { - Some("application/json") => serde_json::from_slice(&*res.bytes().await?)?, - #[cfg(feature = "cbor")] - Some("application/cbor") => serde_cbor::from_slice(&*res.bytes().await?)?, - _ => { - return Err(RpcError { - data: Some("missing content type".into()), - ..yajrc::INTERNAL_ERROR - }) - } - }, - ) + + match res + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()) + { + Some("application/json") => { + serde_json::from_slice::(&*res.bytes().await.map_err(internal_error)?) + .map_err(parse_error)? + .result + } + #[cfg(feature = "cbor")] + Some("application/cbor") => { + serde_cbor::from_slice::(&*res.bytes().await.map_err(internal_error)?) + .map_err(parse_error)? + .result + } + _ => Err(internal_error("missing content type")), + } } } -pub trait RemoteCommand: LeafCommand { - fn metadata() -> Context::Metadata; - fn subcommands(chain: ParentChain) -> Vec> { - drop(chain); - Vec::new() +#[async_trait::async_trait] +pub trait CliContextSocket: crate::Context { + type Stream: AsyncRead + AsyncWrite + Send; + async fn connect(&self) -> std::io::Result; + async fn call_remote(&self, method: &str, params: Value) -> Result { + let rpc_req = RpcRequest { + id: Some(Id::Number(0.into())), + method: GenericRpcMethod::new(method), + params, + }; + let conn = self.connect().await.map_err(|e| RpcError { + data: Some(e.to_string().into()), + ..yajrc::INTERNAL_ERROR + })?; + tokio::pin!(conn); + let mut buf = serde_json::to_vec(&rpc_req).map_err(|e| RpcError { + data: Some(e.to_string().into()), + ..yajrc::INTERNAL_ERROR + })?; + buf.push(b'\n'); + conn.write_all(&buf).await.map_err(|e| RpcError { + data: Some(e.to_string().into()), + ..yajrc::INTERNAL_ERROR + })?; + let mut line = String::new(); + BufReader::new(conn).read_line(&mut line).await?; + serde_json::from_str::(&line) + .map_err(parse_error)? + .result } } + +pub trait RemoteCommand: LeafCommand {} #[async_trait::async_trait] impl AsyncCommand for T where @@ -258,9 +293,6 @@ where T::Err: From, Context: CliContext + Send + 'static, { - fn metadata() -> Context::Metadata { - T::metadata() - } async fn implementation( self, ctx: Context, @@ -280,7 +312,33 @@ where ) .map_err(parse_error)?) } - fn subcommands(chain: ParentChain) -> Vec> { - T::subcommands(chain) +} + +impl SyncCommand for T +where + T: RemoteCommand + Send + Serialize, + T::Parent: Serialize, + T::Ok: DeserializeOwned, + T::Err: From, + Context: CliContext + Send + 'static, +{ + const BLOCKING: bool = true; + fn implementation( + self, + ctx: Context, + parent: ParentInfo, + ) -> Result { + let mut method = parent.method; + method.push(Self::NAME); + Ok( + imbl_value::from_value(ctx.runtime().block_on(ctx.call_remote( + &method.join("."), + combine( + imbl_value::to_value(&self).map_err(invalid_params)?, + imbl_value::to_value(&parent.params).map_err(invalid_params)?, + )?, + ))?) + .map_err(parse_error)?, + ) } } diff --git a/rpc-toolkit/src/command.rs b/rpc-toolkit/src/command.rs index 05e2549..3dabecf 100644 --- a/rpc-toolkit/src/command.rs +++ b/rpc-toolkit/src/command.rs @@ -17,7 +17,7 @@ pub struct DynCommand { pub(crate) name: &'static str, pub(crate) metadata: Context::Metadata, pub(crate) implementation: Option>, - pub(crate) cli: Option, + pub(crate) cli: Option>, pub(crate) subcommands: Vec, } @@ -31,12 +31,18 @@ pub(crate) struct Implementation { Box, Value) -> Result + Send + Sync>, } -pub(crate) struct CliBindings { +pub(crate) struct CliBindings { pub(crate) cmd: clap::Command, pub(crate) parser: Box Fn(&'a ArgMatches) -> Result + Send + Sync>, - pub(crate) display: Option Result<(), imbl_value::Error> + Send + Sync>>, + pub(crate) display: Option< + Box< + dyn Fn(Context, Vec<&'static str>, Value, Value) -> Result<(), imbl_value::Error> + + Send + + Sync, + >, + >, } -impl CliBindings { +impl CliBindings { pub(crate) fn from_parent() -> Self { Self { cmd: Cmd::command(), @@ -53,10 +59,19 @@ impl CliBindings { display: None, } } - fn from_leaf() -> Self { + fn from_leaf>() -> Self + { Self { - display: Some(Box::new(|res| { - Ok(Cmd::display(imbl_value::from_value(res)?)) + display: Some(Box::new(|ctx, parent_method, params, res| { + let parent_params = imbl_value::from_value(params.clone())?; + Ok(imbl_value::from_value::(params)?.display( + ctx, + ParentInfo { + method: parent_method, + params: parent_params, + }, + imbl_value::from_value(res)?, + )) })), ..Self::from_parent::() } @@ -113,13 +128,18 @@ where /// Implement this for a command that has no implementation, but simply exists to organize subcommands pub trait ParentCommand: Command { - fn metadata() -> Context::Metadata; + fn metadata() -> Context::Metadata { + Context::Metadata::default() + } fn subcommands(chain: ParentChain) -> Vec>; } impl DynCommand { pub fn from_parent< Cmd: ParentCommand + FromArgMatches + CommandFactory + Serialize, - >() -> Self { + >( + contains: Contains, + ) -> Self { + drop(contains); Self { name: Cmd::NAME, metadata: Cmd::metadata(), @@ -128,7 +148,10 @@ impl DynCommand { subcommands: Cmd::subcommands(ParentChain::(PhantomData)), } } - pub fn from_parent_no_cli>() -> Self { + pub fn from_parent_no_cli>( + contains: Contains, + ) -> Self { + drop(contains); Self { name: Cmd::NAME, metadata: Cmd::metadata(), @@ -140,25 +163,27 @@ impl DynCommand { } /// Implement this for any command with an implementation -pub trait LeafCommand: Command { +pub trait LeafCommand: Command { type Ok: DeserializeOwned + Serialize + Send; type Err: From + Into + Send; - fn display(res: Self::Ok); + fn metadata() -> Context::Metadata { + Context::Metadata::default() + } + fn display(self, ctx: Context, parent: ParentInfo, res: Self::Ok); + fn subcommands(chain: ParentChain) -> Vec> { + drop(chain); + Vec::new() + } } /// Implement this if your Command's implementation is async #[async_trait::async_trait] -pub trait AsyncCommand: LeafCommand { - fn metadata() -> Context::Metadata; +pub trait AsyncCommand: LeafCommand { async fn implementation( self, ctx: Context, parent: ParentInfo, ) -> Result; - fn subcommands(chain: ParentChain) -> Vec> { - drop(chain); - Vec::new() - } } impl Implementation { fn for_async>(contains: Contains) -> Self { @@ -238,18 +263,13 @@ impl DynCommand { } /// Implement this if your Command's implementation is not async -pub trait SyncCommand: LeafCommand { +pub trait SyncCommand: LeafCommand { const BLOCKING: bool; - fn metadata() -> Context::Metadata; fn implementation( self, ctx: Context, parent: ParentInfo, ) -> Result; - fn subcommands(chain: ParentChain) -> Vec> { - drop(chain); - Vec::new() - } } impl Implementation { fn for_sync>(contains: Contains) -> Self { diff --git a/rpc-toolkit/src/context.rs b/rpc-toolkit/src/context.rs index b718129..2c47228 100644 --- a/rpc-toolkit/src/context.rs +++ b/rpc-toolkit/src/context.rs @@ -1,7 +1,7 @@ use tokio::runtime::Handle; pub trait Context: Send + 'static { - type Metadata: Default; + type Metadata: Default + Send + Sync; fn runtime(&self) -> Handle { Handle::current() } diff --git a/rpc-toolkit/src/server/http.rs b/rpc-toolkit/src/server/http.rs index e69de29..29bdee6 100644 --- a/rpc-toolkit/src/server/http.rs +++ b/rpc-toolkit/src/server/http.rs @@ -0,0 +1,60 @@ +use std::task::Context; + +use futures::future::BoxFuture; +use http::request::Parts; +use hyper::body::{Bytes, Incoming}; +use hyper::{Request, Response}; +use yajrc::{RpcRequest, RpcResponse}; + +type BoxBody = http_body_util::combinators::BoxBody; + +#[async_trait::async_trait] +pub trait Middleware { + type ProcessHttpRequestResult; + async fn process_http_request( + &self, + req: &mut Request, + ) -> Result>>; + type ProcessRpcRequestResult; + async fn process_rpc_request( + &self, + prev: Self::ProcessHttpRequestResult, + metadata: &Context::Metadata, + req: &mut RpcRequest, + ) -> Result; + type ProcessRpcResponseResult; + async fn process_rpc_response( + &self, + prev: Self::ProcessRpcRequestResult, + res: &mut RpcResponse, + ) -> Self::ProcessRpcResponseResult; + async fn process_http_response( + &self, + prev: Self::ProcessRpcResponseResult, + res: &mut Response, + ); +} + +// pub struct DynMiddleware { +// process_http_request: Box< +// dyn for<'a> Fn( +// &'a mut Request, +// ) -> BoxFuture< +// 'a, +// Result, hyper::Result>>, +// > + Send +// + Sync, +// >, +// } +// type DynProcessRpcRequest<'m, Context: crate::Context> = Box< +// dyn for<'a> FnOnce( +// &'a Context::Metadata, +// &'a mut RpcRequest, +// ) +// -> BoxFuture<'a, Result, DynSkipHandler<'m>>> +// + Send +// + Sync +// + 'm, +// >; +// type DynProcessRpcResponse<'m> = +// Box FnOnce(&'a mut RpcResponse) -> BoxFuture<'a, DynProcessHttpResponse<'m>>>; diff --git a/rpc-toolkit/src/server/mod.rs b/rpc-toolkit/src/server/mod.rs index be25722..2ee1ca1 100644 --- a/rpc-toolkit/src/server/mod.rs +++ b/rpc-toolkit/src/server/mod.rs @@ -1,16 +1,19 @@ +use std::borrow::Cow; use std::sync::Arc; use futures::future::{join_all, BoxFuture}; -use futures::stream::{BoxStream, Fuse}; -use futures::{Future, FutureExt, Stream, StreamExt, TryStreamExt}; +use futures::{Future, FutureExt, Stream, StreamExt}; use imbl_value::Value; -use tokio::runtime::Handle; -use tokio::task::JoinHandle; -use yajrc::{AnyParams, RpcError, RpcMethod, RpcRequest, RpcResponse, SingleOrBatchRpcRequest}; +use yajrc::{AnyParams, AnyRpcMethod, RpcError, RpcMethod}; -use crate::util::{invalid_request, parse_error}; +use crate::util::{invalid_request, JobRunner}; use crate::DynCommand; +type GenericRpcMethod = yajrc::GenericRpcMethod; +type RpcRequest = yajrc::RpcRequest; +type RpcResponse = yajrc::RpcResponse; +type SingleOrBatchRpcRequest = yajrc::SingleOrBatchRpcRequest; + mod http; mod socket; @@ -91,134 +94,58 @@ impl Server { &self, RpcRequest { id, method, params }: RpcRequest, ) -> impl Future + Send + 'static { - let handle = (|| { - Ok::<_, RpcError>(self.handle_command( - method.as_str(), - match params { - AnyParams::Named(a) => serde_json::Value::Object(a).into(), - _ => { - return Err(RpcError { - data: Some("positional parameters unsupported".into()), - ..yajrc::INVALID_PARAMS_ERROR - }) - } - }, - )) - })(); + let handle = (|| Ok::<_, RpcError>(self.handle_command(method.as_str(), params)))(); async move { RpcResponse { id, result: match handle { - Ok(handle) => handle.await.map(serde_json::Value::from), + Ok(handle) => handle.await, Err(e) => Err(e), }, } } } - pub fn handle(&self, request: Value) -> BoxFuture<'static, Result> { - let request = - imbl_value::from_value::(request).map_err(invalid_request); - match request { + pub fn handle( + &self, + request: Result, + ) -> BoxFuture<'static, Result> { + match request.and_then(|request| { + imbl_value::from_value::(request).map_err(invalid_request) + }) { Ok(SingleOrBatchRpcRequest::Single(req)) => { let fut = self.handle_single_request(req); - async { imbl_value::to_value(&fut.await).map_err(parse_error) }.boxed() + async { imbl_value::to_value(&fut.await) }.boxed() } Ok(SingleOrBatchRpcRequest::Batch(reqs)) => { let futs: Vec<_> = reqs .into_iter() .map(|req| self.handle_single_request(req)) .collect(); - async { imbl_value::to_value(&join_all(futs).await).map_err(parse_error) }.boxed() + async { imbl_value::to_value(&join_all(futs).await) }.boxed() } - Err(e) => async { Err(e) }.boxed(), + Err(e) => async { + imbl_value::to_value(&RpcResponse { + id: None, + result: Err(e), + }) + } + .boxed(), } } pub fn stream<'a>( &'a self, requests: impl Stream> + Send + 'a, - ) -> impl Stream> + 'a { - let mut running = RunningCommands::default(); - let mut requests = requests.boxed().fuse(); - async fn next<'a, Context: crate::Context>( - server: &'a Server, - running: &mut RunningCommands, - requests: &mut Fuse>>, - ) -> Result, RpcError> { - loop { - tokio::select! { - req = requests.try_next() => { - let req = req?; - if let Some(req) = req { - running.running.push(tokio::spawn(server.handle(req))); - } else { - running.closed = true; - } - } - res = running.try_next() => { - return res; - } - } - } - } + ) -> impl Stream> + 'a { async_stream::try_stream! { - while let Some(res) = next(self, &mut running, &mut requests).await? { + let mut runner = JobRunner::new(); + let requests = requests.fuse().map(|req| self.handle(req)); + tokio::pin!(requests); + + while let Some(res) = runner.next_result(&mut requests).await.transpose()? { yield res; } } } } - -#[derive(Default)] -struct RunningCommands { - closed: bool, - running: Vec>>, -} - -impl Stream for RunningCommands { - type Item = Result; - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let item = self - .running - .iter_mut() - .enumerate() - .find_map(|(i, f)| match f.poll_unpin(cx) { - std::task::Poll::Pending => None, - std::task::Poll::Ready(e) => Some(( - i, - e.map_err(|e| RpcError { - data: Some(e.to_string().into()), - ..yajrc::INTERNAL_ERROR - }) - .and_then(|a| a), - )), - }); - match item { - Some((idx, res)) => { - drop(self.running.swap_remove(idx)); - std::task::Poll::Ready(Some(res)) - } - None => { - if !self.closed || !self.running.is_empty() { - std::task::Poll::Pending - } else { - std::task::Poll::Ready(None) - } - } - } - } -} -impl Drop for RunningCommands { - fn drop(&mut self) { - for hdl in &self.running { - hdl.abort(); - } - if let Ok(rt) = Handle::try_current() { - rt.block_on(join_all(std::mem::take(&mut self.running).into_iter())); - } - } -} diff --git a/rpc-toolkit/src/server/socket.rs b/rpc-toolkit/src/server/socket.rs index 42e8aa3..e42b42c 100644 --- a/rpc-toolkit/src/server/socket.rs +++ b/rpc-toolkit/src/server/socket.rs @@ -1,25 +1,95 @@ -use futures::{AsyncWrite, Future, Stream}; -use tokio::io::AsyncRead; -use tokio::sync::oneshot; +use std::path::Path; +use std::sync::Arc; + +use futures::{Future, Stream, StreamExt, TryStreamExt}; +use imbl_value::Value; +use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; +use tokio::net::{TcpListener, ToSocketAddrs, UnixListener}; +use tokio::sync::OnceCell; use yajrc::RpcError; +use crate::util::{parse_error, JobRunner}; use crate::Server; -pub struct ShutdownHandle(oneshot::Sender<()>); - -pub struct SocketServer { - server: Server, -} -impl SocketServer { - pub fn run_json( - &self, - listener: impl Stream, - ) -> (ShutdownHandle, impl Future>) { - let (shutdown_send, shutdown_recv) = oneshot::channel(); - (ShutdownHandle(shutdown_send), async move { - //asdf - //adf - Ok(()) - }) +#[derive(Clone)] +pub struct ShutdownHandle(Arc>); +impl ShutdownHandle { + pub fn shutdown(self) { + let _ = self.0.set(()); + } +} + +impl Server { + pub fn run_socket<'a, T: AsyncRead + AsyncWrite + Send>( + &'a self, + listener: impl Stream> + 'a, + error_handler: impl Fn(std::io::Error) + Sync + 'a, + ) -> (ShutdownHandle, impl Future + 'a) { + let shutdown = Arc::new(OnceCell::new()); + (ShutdownHandle(shutdown.clone()), async move { + let mut runner = JobRunner::>::new(); + let jobs = listener.map(|pipe| async { + let pipe = pipe?; + let (r, mut w) = tokio::io::split(pipe); + let stream = self.stream( + tokio_stream::wrappers::LinesStream::new(BufReader::new(r).lines()) + .map_err(|e| RpcError { + data: Some(e.to_string().into()), + ..yajrc::INTERNAL_ERROR + }) + .try_filter_map(|a| async move { + Ok(if a.is_empty() { + None + } else { + Some(serde_json::from_str::(&a).map_err(parse_error)?) + }) + }), + ); + tokio::pin!(stream); + while let Some(res) = stream.next().await { + if let Err(e) = async { + let mut buf = serde_json::to_vec( + &res.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?, + ) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + buf.push(b'\n'); + w.write_all(&buf).await + } + .await + { + error_handler(e) + } + } + Ok(()) + }); + tokio::pin!(jobs); + while let Some(res) = runner.next_result(&mut jobs).await { + if let Err(e) = res { + error_handler(e) + } + } + }) + } + pub fn run_unix<'a>( + &'a self, + path: impl AsRef + 'a, + error_handler: impl Fn(std::io::Error) + Sync + 'a, + ) -> std::io::Result<(ShutdownHandle, impl Future + 'a)> { + let listener = UnixListener::bind(path)?; + Ok(self.run_socket( + tokio_stream::wrappers::UnixListenerStream::new(listener), + error_handler, + )) + } + pub async fn run_tcp<'a>( + &'a self, + addr: impl ToSocketAddrs + 'a, + error_handler: impl Fn(std::io::Error) + Sync + 'a, + ) -> std::io::Result<(ShutdownHandle, impl Future + 'a)> { + let listener = TcpListener::bind(addr).await?; + Ok(self.run_socket( + tokio_stream::wrappers::TcpListenerStream::new(listener), + error_handler, + )) } } diff --git a/rpc-toolkit/src/util.rs b/rpc-toolkit/src/util.rs index b56cb47..00c3540 100644 --- a/rpc-toolkit/src/util.rs +++ b/rpc-toolkit/src/util.rs @@ -1,3 +1,7 @@ +use std::fmt::Display; + +use futures::future::BoxFuture; +use futures::{Future, FutureExt, Stream, StreamExt}; use imbl_value::Value; use serde::de::DeserializeOwned; use serde::Deserialize; @@ -42,13 +46,20 @@ pub fn invalid_request(e: imbl_value::Error) -> RpcError { } } -pub fn parse_error(e: imbl_value::Error) -> RpcError { +pub fn parse_error(e: impl Display) -> RpcError { RpcError { data: Some(e.to_string().into()), ..yajrc::PARSE_ERROR } } +pub fn internal_error(e: impl Display) -> RpcError { + RpcError { + data: Some(e.to_string().into()), + ..yajrc::INTERNAL_ERROR + } +} + pub struct Flat(pub A, pub B); impl<'de, A, B> Deserialize<'de> for Flat where @@ -65,3 +76,72 @@ where Ok(Flat(a, b)) } } + +pub fn poll_select_all<'a, T>( + futs: &mut Vec>, + cx: &mut std::task::Context<'_>, +) -> std::task::Poll { + let item = futs + .iter_mut() + .enumerate() + .find_map(|(i, f)| match f.poll_unpin(cx) { + std::task::Poll::Pending => None, + std::task::Poll::Ready(e) => Some((i, e)), + }); + match item { + Some((idx, res)) => { + drop(futs.swap_remove(idx)); + std::task::Poll::Ready(res) + } + None => std::task::Poll::Pending, + } +} + +pub struct JobRunner<'a, T> { + closed: bool, + running: Vec>, +} +impl<'a, T> JobRunner<'a, T> { + pub fn new() -> Self { + JobRunner { + closed: false, + running: Vec::new(), + } + } + pub async fn next_result< + Src: Stream + Unpin, + Fut: Future + Send + 'a, + >( + &mut self, + job_source: &mut Src, + ) -> Option { + loop { + tokio::select! { + job = job_source.next() => { + if let Some(job) = job { + self.running.push(job.boxed()); + } else { + self.closed = true; + } + } + res = self.next() => { + return res; + } + } + } + } +} +impl<'a, T> Stream for JobRunner<'a, T> { + type Item = T; + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match poll_select_all(&mut self.running, cx) { + std::task::Poll::Pending if self.closed && self.running.is_empty() => { + std::task::Poll::Ready(None) + } + a => a.map(Some), + } + } +} diff --git a/rpc-toolkit/tests/compat.rs b/rpc-toolkit/tests/compat.rs index c3d59b1..a96d00c 100644 --- a/rpc-toolkit/tests/compat.rs +++ b/rpc-toolkit/tests/compat.rs @@ -1,214 +1,205 @@ -use std::fmt::Display; -use std::str::FromStr; -use std::sync::Arc; +// use std::fmt::Display; +// use std::str::FromStr; +// use std::sync::Arc; -use futures::FutureExt; -use hyper::Request; -use rpc_toolkit::clap::Arg; -use rpc_toolkit::hyper::http::Error as HttpError; -use rpc_toolkit::hyper::{Body, Response}; -use rpc_toolkit::rpc_server_helpers::{ - DynMiddlewareStage2, DynMiddlewareStage3, DynMiddlewareStage4, -}; -use rpc_toolkit::serde::{Deserialize, Serialize}; -use rpc_toolkit::url::Host; -use rpc_toolkit::yajrc::RpcError; -use rpc_toolkit::{command, rpc_server, run_cli, Context, Metadata}; +// use futures::FutureExt; +// use hyper::Request; +// use rpc_toolkit::clap::Arg; +// use rpc_toolkit::hyper::http::Error as HttpError; +// use rpc_toolkit::hyper::{Body, Response}; +// use rpc_toolkit::rpc_server_helpers::{ +// DynMiddlewareStage2, DynMiddlewareStage3, DynMiddlewareStage4, +// }; +// use rpc_toolkit::serde::{Deserialize, Serialize}; +// use rpc_toolkit::url::Host; +// use rpc_toolkit::yajrc::RpcError; +// use rpc_toolkit::{command, rpc_server, run_cli, Context, Metadata}; -#[derive(Debug, Clone)] -pub struct AppState(Arc); -impl From for () { - fn from(_: AppState) -> Self { - () - } -} +// #[derive(Debug, Clone)] +// pub struct AppState(Arc); +// impl From for () { +// fn from(_: AppState) -> Self { +// () +// } +// } -#[derive(Debug)] -pub struct ConfigSeed { - host: Host, - port: u16, -} +// #[derive(Debug)] +// pub struct ConfigSeed { +// host: Host, +// port: u16, +// } -impl Context for AppState { - fn host(&self) -> Host<&str> { - match &self.0.host { - Host::Domain(s) => Host::Domain(s.as_str()), - Host::Ipv4(i) => Host::Ipv4(*i), - Host::Ipv6(i) => Host::Ipv6(*i), - } - } - fn port(&self) -> u16 { - self.0.port - } -} +// impl Context for AppState { +// type Metadata = (); +// } -fn test_string() -> String { - "test".to_owned() -} +// fn test_string() -> String { +// "test".to_owned() +// } -#[command( - about = "Does the thing", - subcommands("dothething2::", self(dothething_impl(async))) -)] -async fn dothething< - U: Serialize + for<'a> Deserialize<'a> + FromStr + Clone + 'static, - E: Display, ->( - #[context] _ctx: AppState, - #[arg(short = 'a')] arg1: Option, - #[arg(short = 'b', default = "test_string")] val: String, - #[arg(short = 'c', help = "I am the flag `c`!", default)] arg3: bool, - #[arg(stdin)] structured: U, -) -> Result<(Option, String, bool, U), RpcError> { - Ok((arg1, val, arg3, structured)) -} +// #[command( +// about = "Does the thing", +// subcommands("dothething2::", self(dothething_impl(async))) +// )] +// async fn dothething< +// U: Serialize + for<'a> Deserialize<'a> + FromStr + Clone + 'static, +// E: Display, +// >( +// #[context] _ctx: AppState, +// #[arg(short = 'a')] arg1: Option, +// #[arg(short = 'b', default = "test_string")] val: String, +// #[arg(short = 'c', help = "I am the flag `c`!", default)] arg3: bool, +// #[arg(stdin)] structured: U, +// ) -> Result<(Option, String, bool, U), RpcError> { +// Ok((arg1, val, arg3, structured)) +// } -async fn dothething_impl( - ctx: AppState, - parent_data: (Option, String, bool, U), -) -> Result { - Ok(format!( - "{:?}, {:?}, {}, {}, {}", - ctx, - parent_data.0, - parent_data.1, - parent_data.2, - serde_json::to_string_pretty(&parent_data.3)? - )) -} +// async fn dothething_impl( +// ctx: AppState, +// parent_data: (Option, String, bool, U), +// ) -> Result { +// Ok(format!( +// "{:?}, {:?}, {}, {}, {}", +// ctx, +// parent_data.0, +// parent_data.1, +// parent_data.2, +// serde_json::to_string_pretty(&parent_data.3)? +// )) +// } -#[command(about = "Does the thing")] -fn dothething2 Deserialize<'a> + FromStr, E: Display>( - #[parent_data] parent_data: (Option, String, bool, U), - #[arg(stdin)] structured2: U, -) -> Result { - Ok(format!( - "{:?}, {}, {}, {}, {}", - parent_data.0, - parent_data.1, - parent_data.2, - serde_json::to_string_pretty(&parent_data.3)?, - serde_json::to_string_pretty(&structured2)?, - )) -} +// #[command(about = "Does the thing")] +// fn dothething2 Deserialize<'a> + FromStr, E: Display>( +// #[parent_data] parent_data: (Option, String, bool, U), +// #[arg(stdin)] structured2: U, +// ) -> Result { +// Ok(format!( +// "{:?}, {}, {}, {}, {}", +// parent_data.0, +// parent_data.1, +// parent_data.2, +// serde_json::to_string_pretty(&parent_data.3)?, +// serde_json::to_string_pretty(&structured2)?, +// )) +// } -async fn cors( - req: &mut Request, - _: M, -) -> Result>, HttpError> { - if req.method() == hyper::Method::OPTIONS { - Ok(Err(Response::builder() - .header("Access-Control-Allow-Origin", "*") - .body(Body::empty())?)) - } else { - Ok(Ok(Box::new(|_, _| { - async move { - let res: DynMiddlewareStage3 = Box::new(|_, _| { - async move { - let res: DynMiddlewareStage4 = Box::new(|res| { - async move { - res.headers_mut() - .insert("Access-Control-Allow-Origin", "*".parse()?); - Ok::<_, HttpError>(()) - } - .boxed() - }); - Ok::<_, HttpError>(Ok(res)) - } - .boxed() - }); - Ok::<_, HttpError>(Ok(res)) - } - .boxed() - }))) - } -} +// async fn cors( +// req: &mut Request, +// _: M, +// ) -> Result>, HttpError> { +// if req.method() == hyper::Method::OPTIONS { +// Ok(Err(Response::builder() +// .header("Access-Control-Allow-Origin", "*") +// .body(Body::empty())?)) +// } else { +// Ok(Ok(Box::new(|_, _| { +// async move { +// let res: DynMiddlewareStage3 = Box::new(|_, _| { +// async move { +// let res: DynMiddlewareStage4 = Box::new(|res| { +// async move { +// res.headers_mut() +// .insert("Access-Control-Allow-Origin", "*".parse()?); +// Ok::<_, HttpError>(()) +// } +// .boxed() +// }); +// Ok::<_, HttpError>(Ok(res)) +// } +// .boxed() +// }); +// Ok::<_, HttpError>(Ok(res)) +// } +// .boxed() +// }))) +// } +// } -#[tokio::test] -async fn test_rpc() { - use tokio::io::AsyncWriteExt; +// #[tokio::test] +// async fn test_rpc() { +// use tokio::io::AsyncWriteExt; - let seed = Arc::new(ConfigSeed { - host: Host::parse("localhost").unwrap(), - port: 8000, - }); - let server = rpc_server!({ - command: dothething::, - context: AppState(seed), - middleware: [ - cors, - ], - }); - let handle = tokio::spawn(server); - let mut cmd = tokio::process::Command::new("cargo") - .arg("test") - .arg("--package") - .arg("rpc-toolkit") - .arg("--test") - .arg("test") - .arg("--") - .arg("cli_test") - .arg("--exact") - .arg("--nocapture") - .arg("--") - // .arg("-b") - // .arg("test") - .arg("dothething2") - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .spawn() - .unwrap(); - cmd.stdin - .take() - .unwrap() - .write_all(b"TEST\nHAHA") - .await - .unwrap(); - let out = cmd.wait_with_output().await.unwrap(); - assert!(out.status.success()); - assert!(dbg!(std::str::from_utf8(&out.stdout).unwrap()) - .contains("\nNone, test, false, \"TEST\", \"HAHA\"\n")); - handle.abort(); -} +// let seed = Arc::new(ConfigSeed { +// host: Host::parse("localhost").unwrap(), +// port: 8000, +// }); +// let server = rpc_server!({ +// command: dothething::, +// context: AppState(seed), +// middleware: [ +// cors, +// ], +// }); +// let handle = tokio::spawn(server); +// let mut cmd = tokio::process::Command::new("cargo") +// .arg("test") +// .arg("--package") +// .arg("rpc-toolkit") +// .arg("--test") +// .arg("test") +// .arg("--") +// .arg("cli_test") +// .arg("--exact") +// .arg("--nocapture") +// .arg("--") +// // .arg("-b") +// // .arg("test") +// .arg("dothething2") +// .stdin(std::process::Stdio::piped()) +// .stdout(std::process::Stdio::piped()) +// .spawn() +// .unwrap(); +// cmd.stdin +// .take() +// .unwrap() +// .write_all(b"TEST\nHAHA") +// .await +// .unwrap(); +// let out = cmd.wait_with_output().await.unwrap(); +// assert!(out.status.success()); +// assert!(dbg!(std::str::from_utf8(&out.stdout).unwrap()) +// .contains("\nNone, test, false, \"TEST\", \"HAHA\"\n")); +// handle.abort(); +// } -#[test] -fn cli_test() { - let app = dothething::build_app(); - let mut skip = true; - let args = std::iter::once(std::ffi::OsString::from("cli_test")) - .chain(std::env::args_os().into_iter().skip_while(|a| { - if a == "--" { - skip = false; - return true; - } - skip - })) - .collect::>(); - if skip { - return; - } - let matches = app.get_matches_from(args); - let seed = Arc::new(ConfigSeed { - host: Host::parse("localhost").unwrap(), - port: 8000, - }); - dothething::cli_handler::(AppState(seed), (), None, &matches, "".into(), ()) - .unwrap(); -} +// #[test] +// fn cli_test() { +// let app = dothething::build_app(); +// let mut skip = true; +// let args = std::iter::once(std::ffi::OsString::from("cli_test")) +// .chain(std::env::args_os().into_iter().skip_while(|a| { +// if a == "--" { +// skip = false; +// return true; +// } +// skip +// })) +// .collect::>(); +// if skip { +// return; +// } +// let matches = app.get_matches_from(args); +// let seed = Arc::new(ConfigSeed { +// host: Host::parse("localhost").unwrap(), +// port: 8000, +// }); +// dothething::cli_handler::(AppState(seed), (), None, &matches, "".into(), ()) +// .unwrap(); +// } -#[test] -#[ignore] -fn cli_example() { - run_cli! ({ - command: dothething::, - app: app => app - .arg(Arg::with_name("host").long("host").short('h').takes_value(true)) - .arg(Arg::with_name("port").long("port").short('p').takes_value(true)), - context: matches => AppState(Arc::new(ConfigSeed { - host: Host::parse(matches.value_of("host").unwrap_or("localhost")).unwrap(), - port: matches.value_of("port").unwrap_or("8000").parse().unwrap(), - })) - }) -} +// #[test] +// #[ignore] +// fn cli_example() { +// run_cli! ({ +// command: dothething::, +// app: app => app +// .arg(Arg::with_name("host").long("host").short('h').takes_value(true)) +// .arg(Arg::with_name("port").long("port").short('p').takes_value(true)), +// context: matches => AppState(Arc::new(ConfigSeed { +// host: Host::parse(matches.value_of("host").unwrap_or("localhost")).unwrap(), +// port: matches.value_of("port").unwrap_or("8000").parse().unwrap(), +// })) +// }) +// } -//////////////////////////////////////////////// +// //////////////////////////////////////////////// diff --git a/rpc-toolkit/tests/test.rs b/rpc-toolkit/tests/test.rs index 34922a1..8c8b206 100644 --- a/rpc-toolkit/tests/test.rs +++ b/rpc-toolkit/tests/test.rs @@ -1 +1,109 @@ -pub struct App; +use std::path::PathBuf; + +use clap::Parser; +use futures::Future; +use rpc_toolkit::{ + AsyncCommand, CliContextSocket, Command, Contains, Context, DynCommand, LeafCommand, NoParent, + ParentCommand, ParentInfo, Server, ShutdownHandle, +}; +use serde::{Deserialize, Serialize}; +use tokio::net::UnixStream; +use yajrc::RpcError; + +struct ServerContext; +impl Context for ServerContext { + type Metadata = (); +} + +struct CliContext(PathBuf); +impl Context for CliContext { + type Metadata = (); +} +#[async_trait::async_trait] +impl CliContextSocket for CliContext { + type Stream = UnixStream; + async fn connect(&self) -> std::io::Result { + UnixStream::connect(&self.0).await + } +} +#[async_trait::async_trait] +impl rpc_toolkit::CliContext for CliContext { + async fn call_remote( + &self, + method: &str, + params: imbl_value::Value, + ) -> Result { + ::call_remote(self, method, params).await + } +} + +async fn run_server() { + Server::new( + vec![ + DynCommand::from_parent::(Contains::none()), + DynCommand::from_async::(Contains::none()), + // DynCommand::from_async::(Contains::none()), + // DynCommand::from_sync::(Contains::none()), + // DynCommand::from_sync::(Contains::none()), + ], + || async { Ok(ServerContext) }, + ) + .run_unix("./test.sock", |e| eprintln!("{e}")) + .unwrap() + .1 + .await +} + +#[derive(Debug, Deserialize, Serialize, Parser)] +struct Group { + #[arg(short, long)] + verbose: bool, +} +impl Command for Group { + const NAME: &'static str = "group"; + type Parent = NoParent; +} +impl ParentCommand for Group +where + Ctx: Context, + // SubThing: AsyncCommand, + Thing1: AsyncCommand, +{ + fn subcommands(chain: rpc_toolkit::ParentChain) -> Vec> { + vec![ + // DynCommand::from_async::(chain.child()), + DynCommand::from_async::(Contains::none()), + ] + } +} + +#[derive(Debug, Deserialize, Serialize, Parser)] +struct Thing1 { + thing: String, +} +impl Command for Thing1 { + const NAME: &'static str = "thing1"; + type Parent = NoParent; +} +impl LeafCommand for Thing1 { + type Ok = String; + type Err = RpcError; + fn display(self, _: ServerContext, _: rpc_toolkit::ParentInfo, res: Self::Ok) { + println!("{}", res); + } +} +#[async_trait::async_trait] +impl AsyncCommand for Thing1 { + async fn implementation( + self, + _: ServerContext, + _: ParentInfo, + ) -> Result { + Ok(format!("Thing1 is {}", self.thing)) + } +} + +#[tokio::test] +async fn test() { + let server = tokio::spawn(run_server()); +}