diff --git a/rpc-toolkit/src/cli.rs b/rpc-toolkit/src/cli.rs new file mode 100644 index 0000000..c2d2aa4 --- /dev/null +++ b/rpc-toolkit/src/cli.rs @@ -0,0 +1,108 @@ +use clap::ArgMatches; +use imbl_value::Value; +use reqwest::{Client, Method}; +use serde::de::DeserializeOwned; +use serde::Serialize; +use url::Url; +use yajrc::{GenericRpcMethod, Id, RpcError, RpcRequest}; + +use crate::command::{AsyncCommand, DynCommand, LeafCommand, ParentInfo}; +use crate::util::{combine, invalid_params, parse_error}; +use crate::ParentChain; + +pub struct CliApp { + pub(crate) command: DynCommand, + pub(crate) make_ctx: Box Result>, +} + +#[async_trait::async_trait] +pub trait CliContext { + async fn call_remote(&self, method: &str, params: Value) -> Result; +} + +pub trait CliContextHttp { + fn client(&self) -> &Client; + fn url(&self) -> Url; +} +#[async_trait::async_trait] +impl CliContext for T { + async fn call_remote(&self, method: &str, params: Value) -> Result { + let rpc_req: RpcRequest> = RpcRequest { + id: Some(Id::Number(0.into())), + method: GenericRpcMethod::new(method), + params, + }; + let mut req = self.client().request(Method::POST, self.url()); + let body; + #[cfg(feature = "cbor")] + { + req = req.header("content-type", "application/cbor"); + req = req.header("accept", "application/cbor, application/json"); + body = serde_cbor::to_vec(&rpc_req)?; + } + #[cfg(not(feature = "cbor"))] + { + req = req.header("content-type", "application/json"); + req = req.header("accept", "application/json"); + body = serde_json::to_vec(&req)?; + } + let res = req + .header("content-length", body.len()) + .body(body) + .send() + .await?; + Ok( + match res + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()) + { + Some("application/json") => serde_json::from_slice(&*res.bytes().await?)?, + #[cfg(feature = "cbor")] + Some("application/cbor") => serde_cbor::from_slice(&*res.bytes().await?)?, + _ => { + return Err(RpcError { + data: Some("missing content type".into()), + ..yajrc::INTERNAL_ERROR + }) + } + }, + ) + } +} + +pub trait RemoteCommand: LeafCommand { + fn subcommands(chain: ParentChain) -> Vec> { + drop(chain); + Vec::new() + } +} +#[async_trait::async_trait] +impl AsyncCommand for T +where + T: RemoteCommand + Send + Serialize, + T::Parent: Serialize, + T::Ok: DeserializeOwned, + T::Err: From, + Context: CliContext + Send + 'static, +{ + async fn implementation( + self, + ctx: Context, + parent: ParentInfo, + ) -> Result { + let mut method = parent.method; + method.push(Self::NAME); + Ok(imbl_value::from_value( + ctx.call_remote( + &method.join("."), + combine( + imbl_value::to_value(&self).map_err(invalid_params)?, + imbl_value::to_value(&parent.args).map_err(invalid_params)?, + )?, + ) + .await?, + ) + .map_err(parse_error)?) + } +} diff --git a/rpc-toolkit/src/command.rs b/rpc-toolkit/src/command.rs index e49d9f8..e6147cd 100644 --- a/rpc-toolkit/src/command.rs +++ b/rpc-toolkit/src/command.rs @@ -10,17 +10,22 @@ use serde::ser::Serialize; use tokio::runtime::Runtime; use yajrc::RpcError; +use crate::util::{combine, extract, Flat}; + +/// Stores a command's implementation for a given context +/// Can be created from anything that implements ParentCommand, AsyncCommand, or SyncCommand pub struct DynCommand { - name: &'static str, - implementation: Option>, - cli: Option, - subcommands: Vec, + pub(crate) name: &'static str, + pub(crate) implementation: Option>, + pub(crate) cli: Option, + pub(crate) subcommands: Vec, } impl DynCommand { - fn cli_app(&self) -> Option { + pub(crate) fn cli_app(&self) -> Option { if let Some(cli) = &self.cli { Some( cli.cmd + .clone() .name(self.name) .subcommands(self.subcommands.iter().filter_map(|c| c.cli_app())), ) @@ -28,7 +33,7 @@ impl DynCommand { None } } - fn impl_from_cli_matches( + pub(crate) fn impl_from_cli_matches( &self, matches: &ArgMatches, parent: Value, @@ -53,12 +58,13 @@ impl DynCommand { Err(yajrc::METHOD_NOT_FOUND_ERROR) } } - pub fn run_cli(ctx: Context) {} } struct Implementation { - async_impl: Arc BoxFuture<'static, Result>>, - sync_impl: Arc Result>, + pub(crate) async_impl: Arc< + dyn Fn(Context, Vec<&'static str>, Value) -> BoxFuture<'static, Result>, + >, + pub(crate) sync_impl: Arc, Value) -> Result>, } impl Clone for Implementation { fn clone(&self) -> Self { @@ -72,7 +78,7 @@ impl Clone for Implementation { struct CliBindings { cmd: clap::Command, parser: Box Fn(&'a ArgMatches) -> Result + Send + Sync>, - display: Option>, + display: Option Result<(), imbl_value::Error> + Send + Sync>>, } impl CliBindings { fn from_parent() -> Self { @@ -93,36 +99,50 @@ impl CliBindings { } fn from_leaf() -> Self { Self { - display: Some(Box::new(|res| Cmd::display(todo!("{}", res)))), + display: Some(Box::new(|res| { + Ok(Cmd::display(imbl_value::from_value(res)?)) + })), ..Self::from_parent::() } } } -pub trait Command: DeserializeOwned + Sized { +/// Must be implemented for all commands +/// Use `Parent = NoParent` if the implementation requires no arguments from the parent command +pub trait Command: DeserializeOwned + Sized + Send { const NAME: &'static str; type Parent: Command; } +/// Includes the parent method, and the arguments requested from the parent +/// Arguments are flattened out in the params object, so ensure that there are no collisions between the names of the arguments for your method and its parents +pub struct ParentInfo { + pub method: Vec<&'static str>, + pub args: T, +} + +/// This is automatically generated from a command based on its Parents. +/// It can be used to generate a proof that one of the parents contains the necessary arguments that a subcommand requires. pub struct ParentChain(PhantomData); pub struct Contains(PhantomData); -impl From<(Contains, Contains)> for Contains<(T, U)> { - fn from(value: (Contains, Contains)) -> Self { +impl From<(Contains, Contains)> for Contains> { + fn from(_: (Contains, Contains)) -> Self { Self(PhantomData) } } +/// Use this as a Parent if your command does not require any arguments from its parents #[derive(serde::Deserialize, serde::Serialize)] -pub struct Root {} -impl Command for Root { +pub struct NoParent {} +impl Command for NoParent { const NAME: &'static str = ""; - type Parent = Root; + type Parent = NoParent; } impl ParentChain where Cmd: Command, { - pub fn unit(&self) -> Contains<()> { + pub fn none(&self) -> Contains { Contains(PhantomData) } pub fn child(&self) -> Contains { @@ -133,6 +153,7 @@ where } } +/// Implement this for a command that has no implementation, but simply exists to organize subcommands pub trait ParentCommand: Command { fn subcommands(chain: ParentChain) -> Vec>; } @@ -147,34 +168,52 @@ impl DynCommand { subcommands: Cmd::subcommands(ParentChain::(PhantomData)), } } + pub fn from_parent_no_cli>() -> Self { + Self { + name: Cmd::NAME, + implementation: None, + cli: None, + subcommands: Cmd::subcommands(ParentChain::(PhantomData)), + } + } } +/// Implement this for any command with an implementation pub trait LeafCommand: Command { - type Ok: Serialize; - type Err: Into; + type Ok: DeserializeOwned + Serialize + Send; + type Err: From + Into + Send; fn display(res: Self::Ok); } +/// Implement this if your Command's implementation is async #[async_trait::async_trait] pub trait AsyncCommand: LeafCommand { async fn implementation( self, ctx: Context, - parent: Self::Parent, + parent: ParentInfo, ) -> Result; fn subcommands(chain: ParentChain) -> Vec> { + drop(chain); Vec::new() } } -impl Implementation { +impl Implementation { fn for_async>(contains: Contains) -> Self { + drop(contains); Self { - async_impl: Arc::new(|ctx, params| { + async_impl: Arc::new(|ctx, method, params| { async move { let parent = extract::(¶ms)?; imbl_value::to_value( &extract::(¶ms)? - .implementation(ctx, parent) + .implementation( + ctx, + ParentInfo { + method, + args: parent, + }, + ) .await .map_err(|e| e.into())?, ) @@ -185,7 +224,7 @@ impl Implementation { } .boxed() }), - sync_impl: Arc::new(|ctx, params| { + sync_impl: Arc::new(|ctx, method, params| { let parent = extract::(¶ms)?; imbl_value::to_value( &Runtime::new() @@ -196,7 +235,13 @@ impl Implementation { data: Some(e.to_string().into()), ..yajrc::INVALID_PARAMS_ERROR })? - .implementation(ctx, parent), + .implementation( + ctx, + ParentInfo { + method, + args: parent, + }, + ), ) .map_err(|e| e.into())?, ) @@ -208,7 +253,7 @@ impl Implementation { } } } -impl DynCommand { +impl DynCommand { pub fn from_async + FromArgMatches + CommandFactory + Serialize>( contains: Contains, ) -> Self { @@ -219,20 +264,35 @@ impl DynCommand { subcommands: Cmd::subcommands(ParentChain::(PhantomData)), } } + pub fn from_async_no_cli>(contains: Contains) -> Self { + Self { + name: Cmd::NAME, + implementation: Some(Implementation::for_async::(contains)), + cli: None, + subcommands: Cmd::subcommands(ParentChain::(PhantomData)), + } + } } +/// Implement this if your Command's implementation is not async pub trait SyncCommand: LeafCommand { const BLOCKING: bool; - fn implementation(self, ctx: Context, parent: Self::Parent) -> Result; + fn implementation( + self, + ctx: Context, + parent: ParentInfo, + ) -> Result; fn subcommands(chain: ParentChain) -> Vec> { + drop(chain); Vec::new() } } -impl Implementation { +impl Implementation { fn for_sync>(contains: Contains) -> Self { + drop(contains); Self { async_impl: if Cmd::BLOCKING { - Arc::new(|ctx, params| { + Arc::new(|ctx, method, params| { tokio::task::spawn_blocking(move || { let parent = extract::(¶ms)?; imbl_value::to_value( @@ -241,7 +301,13 @@ impl Implementation { data: Some(e.to_string().into()), ..yajrc::INVALID_PARAMS_ERROR })? - .implementation(ctx, parent) + .implementation( + ctx, + ParentInfo { + method, + args: parent, + }, + ) .map_err(|e| e.into())?, ) .map_err(|e| RpcError { @@ -258,7 +324,7 @@ impl Implementation { .boxed() }) } else { - Arc::new(|ctx, params| { + Arc::new(|ctx, method, params| { async move { let parent = extract::(¶ms)?; imbl_value::to_value( @@ -267,7 +333,13 @@ impl Implementation { data: Some(e.to_string().into()), ..yajrc::INVALID_PARAMS_ERROR })? - .implementation(ctx, parent) + .implementation( + ctx, + ParentInfo { + method, + args: parent, + }, + ) .map_err(|e| e.into())?, ) .map_err(|e| RpcError { @@ -278,7 +350,7 @@ impl Implementation { .boxed() }) }, - sync_impl: Arc::new(|ctx, params| { + sync_impl: Arc::new(|ctx, method, params| { let parent = extract::(¶ms)?; imbl_value::to_value( &extract::(¶ms) @@ -286,7 +358,13 @@ impl Implementation { data: Some(e.to_string().into()), ..yajrc::INVALID_PARAMS_ERROR })? - .implementation(ctx, parent) + .implementation( + ctx, + ParentInfo { + method, + args: parent, + }, + ) .map_err(|e| e.into())?, ) .map_err(|e| RpcError { @@ -297,7 +375,7 @@ impl Implementation { } } } -impl DynCommand { +impl DynCommand { pub fn from_sync + FromArgMatches + CommandFactory + Serialize>( contains: Contains, ) -> Self { @@ -308,29 +386,12 @@ impl DynCommand { subcommands: Cmd::subcommands(ParentChain::(PhantomData)), } } -} - -fn extract(value: &Value) -> Result { - imbl_value::from_value(value.clone()).map_err(|e| RpcError { - data: Some(e.to_string().into()), - ..yajrc::INVALID_PARAMS_ERROR - }) -} - -fn combine(v1: Value, v2: Value) -> Result { - let (Value::Object(mut v1), Value::Object(v2)) = (v1, v2) else { - return Err(RpcError { - data: Some("params must be object".into()), - ..yajrc::INVALID_PARAMS_ERROR - }); - }; - for (key, value) in v2 { - if v1.insert(key.clone(), value).is_some() { - return Err(RpcError { - data: Some(format!("duplicate key: {key}").into()), - ..yajrc::INVALID_PARAMS_ERROR - }); + pub fn from_sync_no_cli>(contains: Contains) -> Self { + Self { + name: Cmd::NAME, + implementation: Some(Implementation::for_sync::(contains)), + cli: None, + subcommands: Cmd::subcommands(ParentChain::(PhantomData)), } } - Ok(Value::Object(v1)) } diff --git a/rpc-toolkit/src/lib.rs b/rpc-toolkit/src/lib.rs index 89be369..1613002 100644 --- a/rpc-toolkit/src/lib.rs +++ b/rpc-toolkit/src/lib.rs @@ -1,3 +1,5 @@ +pub use cli::*; +pub use command::*; /// `#[command(...)]` /// - `#[command(cli_only)]` -> executed by CLI instead of RPC server (leaf commands only) /// - `#[command(rpc_only)]` -> no CLI bindings (leaf commands only) @@ -20,40 +22,12 @@ /// /// See also: [arg](rpc_toolkit_macro::arg), [context](rpc_toolkit_macro::context) pub use rpc_toolkit_macro::command; -/// `rpc_handler!(command, context, status_fn)` -/// - returns: [RpcHandler](rpc_toolkit::RpcHandler) -/// - `command`: path to an rpc command (with the `#[command]` attribute) -/// - `context`: The [Context] for `command`. Must implement [Clone](std::clone::Clone). -/// - `status_fn` (optional): a function that takes a JSON RPC error code (`i32`) and returns a [StatusCode](hyper::StatusCode) -/// - default: `|_| StatusCode::OK` -pub use rpc_toolkit_macro::rpc_handler; -/// `rpc_server!(command, context, status_fn)` -/// - returns: [Server](hyper::Server) -/// - `command`: path to an rpc command (with the `#[command]` attribute) -/// - `context`: The [Context] for `command`. Must implement [Clone](std::clone::Clone). -/// - `status_fn` (optional): a function that takes a JSON RPC error code (`i32`) and returns a [StatusCode](hyper::StatusCode) -/// - default: `|_| StatusCode::OK` -pub use rpc_toolkit_macro::rpc_server; -/// `run_cli!(command, app_mutator, make_ctx, exit_fn)` -/// - this function does not return -/// - `command`: path to an rpc command (with the `#[command]` attribute) -/// - `app_mutator` (optional): an expression that returns a mutated app. -/// - example: `app => app.arg(Arg::with_name("port").long("port"))` -/// - default: `app => app` -/// - `make_ctx` (optional): an expression that takes [&ArgMatches](clap::ArgMatches) and returns the [Context] used by `command`. -/// - example: `matches => matches.value_of("port")` -/// - default: `matches => matches` -/// - `exit_fn` (optional): a function that takes a JSON RPC error code (`i32`) and returns an Exit code (`i32`) -/// - default: `|code| code` -pub use rpc_toolkit_macro::run_cli; pub use {clap, futures, hyper, reqwest, serde, serde_json, tokio, url, yajrc}; -pub use crate::context::Context; -pub use crate::metadata::Metadata; -pub use crate::rpc_server_helpers::RpcHandler; - -mod command; -pub mod command_helpers; -mod context; -mod metadata; -pub mod rpc_server_helpers; +pub(crate) mod cli; +pub(crate) mod command; +// pub mod command_helpers; +// mod context; +// mod metadata; +// pub mod rpc_server_helpers; +pub(crate) mod util; diff --git a/rpc-toolkit/src/util.rs b/rpc-toolkit/src/util.rs new file mode 100644 index 0000000..eccfad9 --- /dev/null +++ b/rpc-toolkit/src/util.rs @@ -0,0 +1,60 @@ +use imbl_value::Value; +use serde::de::DeserializeOwned; +use serde::Deserialize; +use yajrc::RpcError; + +pub fn extract(value: &Value) -> Result { + imbl_value::from_value(value.clone()).map_err(|e| RpcError { + data: Some(e.to_string().into()), + ..yajrc::INVALID_PARAMS_ERROR + }) +} + +pub fn combine(v1: Value, v2: Value) -> Result { + let (Value::Object(mut v1), Value::Object(v2)) = (v1, v2) else { + return Err(RpcError { + data: Some("params must be object".into()), + ..yajrc::INVALID_PARAMS_ERROR + }); + }; + for (key, value) in v2 { + if v1.insert(key.clone(), value).is_some() { + return Err(RpcError { + data: Some(format!("duplicate key: {key}").into()), + ..yajrc::INVALID_PARAMS_ERROR + }); + } + } + Ok(Value::Object(v1)) +} + +pub fn invalid_params(e: imbl_value::Error) -> RpcError { + RpcError { + data: Some(e.to_string().into()), + ..yajrc::INVALID_PARAMS_ERROR + } +} + +pub fn parse_error(e: imbl_value::Error) -> RpcError { + RpcError { + data: Some(e.to_string().into()), + ..yajrc::PARSE_ERROR + } +} + +pub struct Flat(pub A, pub B); +impl<'de, A, B> Deserialize<'de> for Flat +where + A: DeserializeOwned, + B: DeserializeOwned, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let v = Value::deserialize(deserializer)?; + let a = imbl_value::from_value(v.clone()).map_err(serde::de::Error::custom)?; + let b = imbl_value::from_value(v).map_err(serde::de::Error::custom)?; + Ok(Flat(a, b)) + } +} diff --git a/rpc-toolkit/tests/compat.rs b/rpc-toolkit/tests/compat.rs new file mode 100644 index 0000000..c3d59b1 --- /dev/null +++ b/rpc-toolkit/tests/compat.rs @@ -0,0 +1,214 @@ +use std::fmt::Display; +use std::str::FromStr; +use std::sync::Arc; + +use futures::FutureExt; +use hyper::Request; +use rpc_toolkit::clap::Arg; +use rpc_toolkit::hyper::http::Error as HttpError; +use rpc_toolkit::hyper::{Body, Response}; +use rpc_toolkit::rpc_server_helpers::{ + DynMiddlewareStage2, DynMiddlewareStage3, DynMiddlewareStage4, +}; +use rpc_toolkit::serde::{Deserialize, Serialize}; +use rpc_toolkit::url::Host; +use rpc_toolkit::yajrc::RpcError; +use rpc_toolkit::{command, rpc_server, run_cli, Context, Metadata}; + +#[derive(Debug, Clone)] +pub struct AppState(Arc); +impl From for () { + fn from(_: AppState) -> Self { + () + } +} + +#[derive(Debug)] +pub struct ConfigSeed { + host: Host, + port: u16, +} + +impl Context for AppState { + fn host(&self) -> Host<&str> { + match &self.0.host { + Host::Domain(s) => Host::Domain(s.as_str()), + Host::Ipv4(i) => Host::Ipv4(*i), + Host::Ipv6(i) => Host::Ipv6(*i), + } + } + fn port(&self) -> u16 { + self.0.port + } +} + +fn test_string() -> String { + "test".to_owned() +} + +#[command( + about = "Does the thing", + subcommands("dothething2::", self(dothething_impl(async))) +)] +async fn dothething< + U: Serialize + for<'a> Deserialize<'a> + FromStr + Clone + 'static, + E: Display, +>( + #[context] _ctx: AppState, + #[arg(short = 'a')] arg1: Option, + #[arg(short = 'b', default = "test_string")] val: String, + #[arg(short = 'c', help = "I am the flag `c`!", default)] arg3: bool, + #[arg(stdin)] structured: U, +) -> Result<(Option, String, bool, U), RpcError> { + Ok((arg1, val, arg3, structured)) +} + +async fn dothething_impl( + ctx: AppState, + parent_data: (Option, String, bool, U), +) -> Result { + Ok(format!( + "{:?}, {:?}, {}, {}, {}", + ctx, + parent_data.0, + parent_data.1, + parent_data.2, + serde_json::to_string_pretty(&parent_data.3)? + )) +} + +#[command(about = "Does the thing")] +fn dothething2 Deserialize<'a> + FromStr, E: Display>( + #[parent_data] parent_data: (Option, String, bool, U), + #[arg(stdin)] structured2: U, +) -> Result { + Ok(format!( + "{:?}, {}, {}, {}, {}", + parent_data.0, + parent_data.1, + parent_data.2, + serde_json::to_string_pretty(&parent_data.3)?, + serde_json::to_string_pretty(&structured2)?, + )) +} + +async fn cors( + req: &mut Request, + _: M, +) -> Result>, HttpError> { + if req.method() == hyper::Method::OPTIONS { + Ok(Err(Response::builder() + .header("Access-Control-Allow-Origin", "*") + .body(Body::empty())?)) + } else { + Ok(Ok(Box::new(|_, _| { + async move { + let res: DynMiddlewareStage3 = Box::new(|_, _| { + async move { + let res: DynMiddlewareStage4 = Box::new(|res| { + async move { + res.headers_mut() + .insert("Access-Control-Allow-Origin", "*".parse()?); + Ok::<_, HttpError>(()) + } + .boxed() + }); + Ok::<_, HttpError>(Ok(res)) + } + .boxed() + }); + Ok::<_, HttpError>(Ok(res)) + } + .boxed() + }))) + } +} + +#[tokio::test] +async fn test_rpc() { + use tokio::io::AsyncWriteExt; + + let seed = Arc::new(ConfigSeed { + host: Host::parse("localhost").unwrap(), + port: 8000, + }); + let server = rpc_server!({ + command: dothething::, + context: AppState(seed), + middleware: [ + cors, + ], + }); + let handle = tokio::spawn(server); + let mut cmd = tokio::process::Command::new("cargo") + .arg("test") + .arg("--package") + .arg("rpc-toolkit") + .arg("--test") + .arg("test") + .arg("--") + .arg("cli_test") + .arg("--exact") + .arg("--nocapture") + .arg("--") + // .arg("-b") + // .arg("test") + .arg("dothething2") + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .spawn() + .unwrap(); + cmd.stdin + .take() + .unwrap() + .write_all(b"TEST\nHAHA") + .await + .unwrap(); + let out = cmd.wait_with_output().await.unwrap(); + assert!(out.status.success()); + assert!(dbg!(std::str::from_utf8(&out.stdout).unwrap()) + .contains("\nNone, test, false, \"TEST\", \"HAHA\"\n")); + handle.abort(); +} + +#[test] +fn cli_test() { + let app = dothething::build_app(); + let mut skip = true; + let args = std::iter::once(std::ffi::OsString::from("cli_test")) + .chain(std::env::args_os().into_iter().skip_while(|a| { + if a == "--" { + skip = false; + return true; + } + skip + })) + .collect::>(); + if skip { + return; + } + let matches = app.get_matches_from(args); + let seed = Arc::new(ConfigSeed { + host: Host::parse("localhost").unwrap(), + port: 8000, + }); + dothething::cli_handler::(AppState(seed), (), None, &matches, "".into(), ()) + .unwrap(); +} + +#[test] +#[ignore] +fn cli_example() { + run_cli! ({ + command: dothething::, + app: app => app + .arg(Arg::with_name("host").long("host").short('h').takes_value(true)) + .arg(Arg::with_name("port").long("port").short('p').takes_value(true)), + context: matches => AppState(Arc::new(ConfigSeed { + host: Host::parse(matches.value_of("host").unwrap_or("localhost")).unwrap(), + port: matches.value_of("port").unwrap_or("8000").parse().unwrap(), + })) + }) +} + +//////////////////////////////////////////////// diff --git a/rpc-toolkit/tests/test.rs b/rpc-toolkit/tests/test.rs index c3d59b1..34922a1 100644 --- a/rpc-toolkit/tests/test.rs +++ b/rpc-toolkit/tests/test.rs @@ -1,214 +1 @@ -use std::fmt::Display; -use std::str::FromStr; -use std::sync::Arc; - -use futures::FutureExt; -use hyper::Request; -use rpc_toolkit::clap::Arg; -use rpc_toolkit::hyper::http::Error as HttpError; -use rpc_toolkit::hyper::{Body, Response}; -use rpc_toolkit::rpc_server_helpers::{ - DynMiddlewareStage2, DynMiddlewareStage3, DynMiddlewareStage4, -}; -use rpc_toolkit::serde::{Deserialize, Serialize}; -use rpc_toolkit::url::Host; -use rpc_toolkit::yajrc::RpcError; -use rpc_toolkit::{command, rpc_server, run_cli, Context, Metadata}; - -#[derive(Debug, Clone)] -pub struct AppState(Arc); -impl From for () { - fn from(_: AppState) -> Self { - () - } -} - -#[derive(Debug)] -pub struct ConfigSeed { - host: Host, - port: u16, -} - -impl Context for AppState { - fn host(&self) -> Host<&str> { - match &self.0.host { - Host::Domain(s) => Host::Domain(s.as_str()), - Host::Ipv4(i) => Host::Ipv4(*i), - Host::Ipv6(i) => Host::Ipv6(*i), - } - } - fn port(&self) -> u16 { - self.0.port - } -} - -fn test_string() -> String { - "test".to_owned() -} - -#[command( - about = "Does the thing", - subcommands("dothething2::", self(dothething_impl(async))) -)] -async fn dothething< - U: Serialize + for<'a> Deserialize<'a> + FromStr + Clone + 'static, - E: Display, ->( - #[context] _ctx: AppState, - #[arg(short = 'a')] arg1: Option, - #[arg(short = 'b', default = "test_string")] val: String, - #[arg(short = 'c', help = "I am the flag `c`!", default)] arg3: bool, - #[arg(stdin)] structured: U, -) -> Result<(Option, String, bool, U), RpcError> { - Ok((arg1, val, arg3, structured)) -} - -async fn dothething_impl( - ctx: AppState, - parent_data: (Option, String, bool, U), -) -> Result { - Ok(format!( - "{:?}, {:?}, {}, {}, {}", - ctx, - parent_data.0, - parent_data.1, - parent_data.2, - serde_json::to_string_pretty(&parent_data.3)? - )) -} - -#[command(about = "Does the thing")] -fn dothething2 Deserialize<'a> + FromStr, E: Display>( - #[parent_data] parent_data: (Option, String, bool, U), - #[arg(stdin)] structured2: U, -) -> Result { - Ok(format!( - "{:?}, {}, {}, {}, {}", - parent_data.0, - parent_data.1, - parent_data.2, - serde_json::to_string_pretty(&parent_data.3)?, - serde_json::to_string_pretty(&structured2)?, - )) -} - -async fn cors( - req: &mut Request, - _: M, -) -> Result>, HttpError> { - if req.method() == hyper::Method::OPTIONS { - Ok(Err(Response::builder() - .header("Access-Control-Allow-Origin", "*") - .body(Body::empty())?)) - } else { - Ok(Ok(Box::new(|_, _| { - async move { - let res: DynMiddlewareStage3 = Box::new(|_, _| { - async move { - let res: DynMiddlewareStage4 = Box::new(|res| { - async move { - res.headers_mut() - .insert("Access-Control-Allow-Origin", "*".parse()?); - Ok::<_, HttpError>(()) - } - .boxed() - }); - Ok::<_, HttpError>(Ok(res)) - } - .boxed() - }); - Ok::<_, HttpError>(Ok(res)) - } - .boxed() - }))) - } -} - -#[tokio::test] -async fn test_rpc() { - use tokio::io::AsyncWriteExt; - - let seed = Arc::new(ConfigSeed { - host: Host::parse("localhost").unwrap(), - port: 8000, - }); - let server = rpc_server!({ - command: dothething::, - context: AppState(seed), - middleware: [ - cors, - ], - }); - let handle = tokio::spawn(server); - let mut cmd = tokio::process::Command::new("cargo") - .arg("test") - .arg("--package") - .arg("rpc-toolkit") - .arg("--test") - .arg("test") - .arg("--") - .arg("cli_test") - .arg("--exact") - .arg("--nocapture") - .arg("--") - // .arg("-b") - // .arg("test") - .arg("dothething2") - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .spawn() - .unwrap(); - cmd.stdin - .take() - .unwrap() - .write_all(b"TEST\nHAHA") - .await - .unwrap(); - let out = cmd.wait_with_output().await.unwrap(); - assert!(out.status.success()); - assert!(dbg!(std::str::from_utf8(&out.stdout).unwrap()) - .contains("\nNone, test, false, \"TEST\", \"HAHA\"\n")); - handle.abort(); -} - -#[test] -fn cli_test() { - let app = dothething::build_app(); - let mut skip = true; - let args = std::iter::once(std::ffi::OsString::from("cli_test")) - .chain(std::env::args_os().into_iter().skip_while(|a| { - if a == "--" { - skip = false; - return true; - } - skip - })) - .collect::>(); - if skip { - return; - } - let matches = app.get_matches_from(args); - let seed = Arc::new(ConfigSeed { - host: Host::parse("localhost").unwrap(), - port: 8000, - }); - dothething::cli_handler::(AppState(seed), (), None, &matches, "".into(), ()) - .unwrap(); -} - -#[test] -#[ignore] -fn cli_example() { - run_cli! ({ - command: dothething::, - app: app => app - .arg(Arg::with_name("host").long("host").short('h').takes_value(true)) - .arg(Arg::with_name("port").long("port").short('p').takes_value(true)), - context: matches => AppState(Arc::new(ConfigSeed { - host: Host::parse(matches.value_of("host").unwrap_or("localhost")).unwrap(), - port: matches.value_of("port").unwrap_or("8000").parse().unwrap(), - })) - }) -} - -//////////////////////////////////////////////// +pub struct App;