macros wip

This commit is contained in:
Aiden McClelland
2023-12-27 22:53:59 -07:00
parent 1442d36e5e
commit 434d521c74
25 changed files with 2157 additions and 2053 deletions

37
Cargo.lock generated
View File

@@ -251,6 +251,12 @@ version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f"
[[package]]
name = "either"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
[[package]]
name = "encoding_rs"
version = "0.8.33"
@@ -675,6 +681,15 @@ version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"
[[package]]
name = "itertools"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.10"
@@ -888,6 +903,26 @@ version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "pin-project"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.41",
]
[[package]]
name = "pin-project-lite"
version = "0.2.13"
@@ -1001,6 +1036,7 @@ dependencies = [
"lazy_format",
"lazy_static",
"openssl",
"pin-project",
"reqwest",
"rpc-toolkit-macro",
"serde",
@@ -1026,6 +1062,7 @@ dependencies = [
name = "rpc-toolkit-macro-internals"
version = "0.2.2"
dependencies = [
"itertools",
"proc-macro2",
"quote",
"syn 1.0.109",

View File

@@ -11,3 +11,4 @@ license = "MIT"
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "1.0", features = ["full", "fold"] }
itertools = "0.12"

File diff suppressed because it is too large Load Diff

View File

@@ -25,7 +25,7 @@ impl Default for ExecutionContext {
#[derive(Default)]
pub struct LeafOptions {
macro_debug: bool,
macro_debug: Option<Path>,
blocking: Option<Path>,
is_async: bool,
aliases: Vec<LitStr>,
@@ -34,6 +34,7 @@ pub struct LeafOptions {
exec_ctx: ExecutionContext,
display: Option<Path>,
metadata: HashMap<Ident, Lit>,
clap_attr: Vec<NestedMeta>,
}
pub struct SelfImplInfo {
@@ -79,22 +80,18 @@ impl Options {
}
}
#[derive(Clone)]
pub struct ArgOptions {
ty: Type,
optional: bool,
check_is_present: bool,
help: Option<LitStr>,
name: Option<Ident>,
rename: Option<LitStr>,
short: Option<LitChar>,
long: Option<LitStr>,
parse: Option<Path>,
default: Option<Option<LitStr>>,
count: Option<Path>,
multiple: Option<Path>,
stdin: Option<Path>,
default: Option<Path>,
clap_attr: Vec<NestedMeta>,
}
#[derive(Clone)]
pub enum ParamType {
None,
Arg(ArgOptions),

View File

@@ -7,7 +7,7 @@ pub fn parse_command_attr(args: AttributeArgs) -> Result<Options> {
for arg in args {
match arg {
NestedMeta::Meta(Meta::Path(p)) if p.is_ident("macro_debug") => {
opt.common().macro_debug = true;
opt.common().macro_debug = Some(p);
}
NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("subcommands") => {
let inner = opt.to_parent()?;
@@ -536,25 +536,20 @@ pub fn parse_command_attr(args: AttributeArgs) -> Result<Options> {
pub fn parse_arg_attr(attr: Attribute, arg: PatType) -> Result<ArgOptions> {
let arg_span = arg.span();
let meta = attr.parse_meta()?;
let mut opt = ArgOptions {
ty: *arg.ty,
optional: false,
check_is_present: false,
help: None,
name: match *arg.pat {
Pat::Ident(i) => Some(i.ident),
_ => None,
},
rename: None,
short: None,
long: None,
parse: None,
default: None,
count: None,
multiple: None,
stdin: None,
default: None,
clap_attr: Vec::new(),
};
match attr.parse_meta()? {
match meta {
Meta::List(list) => {
for arg in list.nested {
match arg {
@@ -591,34 +586,10 @@ pub fn parse_arg_attr(attr: Attribute, arg: PatType) -> Result<ArgOptions> {
return Err(Error::new(nv.path.span(), "`parse` cannot be assigned to"));
}
NestedMeta::Meta(Meta::Path(p)) if p.is_ident("stdin") => {
if opt.short.is_some() {
if !opt.clap_attr.is_empty() {
return Err(Error::new(
p.span(),
"`stdin` and `short` are mutually exclusive",
));
}
if opt.long.is_some() {
return Err(Error::new(
p.span(),
"`stdin` and `long` are mutually exclusive",
));
}
if opt.help.is_some() {
return Err(Error::new(
p.span(),
"`stdin` and `help` are mutually exclusive",
));
}
if opt.count.is_some() {
return Err(Error::new(
p.span(),
"`stdin` and `count` are mutually exclusive",
));
}
if opt.multiple.is_some() {
return Err(Error::new(
p.span(),
"`stdin` and `multiple` are mutually exclusive",
"`stdin` and clap parser attributes are mutually exclusive",
));
}
opt.stdin = Some(p);
@@ -632,79 +603,6 @@ pub fn parse_arg_attr(attr: Attribute, arg: PatType) -> Result<ArgOptions> {
NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("stdin") => {
return Err(Error::new(nv.path.span(), "`stdin` cannot be assigned to"));
}
NestedMeta::Meta(Meta::Path(p)) if p.is_ident("count") => {
if opt.stdin.is_some() {
return Err(Error::new(
p.span(),
"`stdin` and `count` are mutually exclusive",
));
}
if opt.multiple.is_some() {
return Err(Error::new(
p.span(),
"`count` and `multiple` are mutually exclusive",
));
}
opt.count = Some(p);
}
NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("count") => {
return Err(Error::new(
list.path.span(),
"`count` does not take any arguments",
));
}
NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("count") => {
return Err(Error::new(nv.path.span(), "`count` cannot be assigned to"));
}
NestedMeta::Meta(Meta::Path(p)) if p.is_ident("multiple") => {
if opt.stdin.is_some() {
return Err(Error::new(
p.span(),
"`stdin` and `multiple` are mutually exclusive",
));
}
if opt.count.is_some() {
return Err(Error::new(
p.span(),
"`count` and `multiple` are mutually exclusive",
));
}
opt.multiple = Some(p);
}
NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("multiple") => {
return Err(Error::new(
list.path.span(),
"`multiple` does not take any arguments",
));
}
NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("count") => {
return Err(Error::new(nv.path.span(), "`count` cannot be assigned to"));
}
NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("help") => {
if let Lit::Str(help) = nv.lit {
if opt.help.is_some() {
return Err(Error::new(help.span(), "duplicate argument `help`"));
}
if opt.stdin.is_some() {
return Err(Error::new(
help.span(),
"`stdin` and `help` are mutually exclusive",
));
}
opt.help = Some(help);
} else {
return Err(Error::new(nv.lit.span(), "help message must be a string"));
}
}
NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("help") => {
return Err(Error::new(
list.path.span(),
"`help` does not take any arguments",
));
}
NestedMeta::Meta(Meta::Path(p)) if p.is_ident("help") => {
return Err(Error::new(p.span(), "`help` must be assigned to"));
}
NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("rename") => {
if let Lit::Str(rename) = nv.lit {
if opt.rename.is_some() {
@@ -713,6 +611,12 @@ pub fn parse_arg_attr(attr: Attribute, arg: PatType) -> Result<ArgOptions> {
"duplicate argument `rename`",
));
}
opt.clap_attr
.push(NestedMeta::Meta(Meta::NameValue(MetaNameValue {
path: Path::from(Ident::new("name", nv.path.span())),
eq_token: nv.eq_token,
lit: Lit::Str(rename.clone()),
})));
opt.rename = Some(rename);
} else {
return Err(Error::new(nv.lit.span(), "`rename` must be a string"));
@@ -727,68 +631,8 @@ pub fn parse_arg_attr(attr: Attribute, arg: PatType) -> Result<ArgOptions> {
NestedMeta::Meta(Meta::Path(p)) if p.is_ident("rename") => {
return Err(Error::new(p.span(), "`rename` must be assigned to"));
}
NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("short") => {
if let Lit::Char(short) = nv.lit {
if opt.short.is_some() {
return Err(Error::new(short.span(), "duplicate argument `short`"));
}
if opt.stdin.is_some() {
return Err(Error::new(
short.span(),
"`stdin` and `short` are mutually exclusive",
));
}
opt.short = Some(short);
} else {
return Err(Error::new(nv.lit.span(), "`short` must be a char"));
}
}
NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("short") => {
return Err(Error::new(
list.path.span(),
"`short` does not take any arguments",
));
}
NestedMeta::Meta(Meta::Path(p)) if p.is_ident("short") => {
return Err(Error::new(p.span(), "`short` must be assigned to"));
}
NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("long") => {
if let Lit::Str(long) = nv.lit {
if opt.long.is_some() {
return Err(Error::new(long.span(), "duplicate argument `long`"));
}
if opt.stdin.is_some() {
return Err(Error::new(
long.span(),
"`stdin` and `long` are mutually exclusive",
));
}
opt.long = Some(long);
} else {
return Err(Error::new(nv.lit.span(), "`long` must be a string"));
}
}
NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("long") => {
return Err(Error::new(
list.path.span(),
"`long` does not take any arguments",
));
}
NestedMeta::Meta(Meta::Path(p)) if p.is_ident("long") => {
return Err(Error::new(p.span(), "`long` must be assigned to"));
}
NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("default") => {
if let Lit::Str(default) = nv.lit {
if opt.default.is_some() {
return Err(Error::new(
default.span(),
"duplicate argument `default`",
));
}
opt.default = Some(Some(default));
} else {
return Err(Error::new(nv.lit.span(), "`default` must be a string"));
}
return Err(Error::new(nv.lit.span(), "`default` cannot be assigned to"));
}
NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("default") => {
return Err(Error::new(
@@ -797,13 +641,21 @@ pub fn parse_arg_attr(attr: Attribute, arg: PatType) -> Result<ArgOptions> {
));
}
NestedMeta::Meta(Meta::Path(p)) if p.is_ident("default") => {
if opt.default.is_some() {
return Err(Error::new(p.span(), "duplicate argument `default`"));
}
opt.default = Some(None);
opt.clap_attr
.push(NestedMeta::Meta(Meta::Path(Path::from(Ident::new(
"default_value_t",
p.span(),
)))));
opt.default = Some(p);
}
_ => {
return Err(Error::new(arg.span(), "unknown argument"));
unknown => {
if opt.stdin.is_some() {
return Err(Error::new(
unknown.span(),
"`stdin` and clap parser attributes are mutually exclusive",
));
}
opt.clap_attr.push(unknown);
}
}
}

View File

@@ -8,14 +8,5 @@ macro_rules! macro_try {
}
mod command;
mod rpc_handler;
mod rpc_server;
mod run_cli;
pub use command::build::build as build_command;
pub use rpc_handler::build::build as build_rpc_handler;
pub use rpc_handler::RpcHandlerArgs;
pub use rpc_server::build::build as build_rpc_server;
pub use rpc_server::RpcServerArgs;
pub use run_cli::build::build as build_run_cli;
pub use run_cli::RunCliArgs;

View File

@@ -1,129 +0,0 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::spanned::Spanned;
use super::*;
pub fn build(args: RpcHandlerArgs) -> TokenStream {
let mut command = args.command;
let mut arguments = std::mem::replace(
&mut command.segments.last_mut().unwrap().arguments,
PathArguments::None,
);
let command_module = command.clone();
if let PathArguments::AngleBracketed(a) = &mut arguments {
a.args.push(syn::parse2(quote! { _ }).unwrap());
}
command.segments.push(PathSegment {
ident: Ident::new("rpc_handler", command.span()),
arguments,
});
let ctx = args.ctx;
let parent_data = if let Some(data) = args.parent_data {
quote! { #data }
} else {
quote! { () }
};
let status_fn = args.status_fn.unwrap_or_else(|| {
syn::parse2(quote! { |_| ::rpc_toolkit::hyper::StatusCode::OK }).unwrap()
});
let middleware_name_clone = (0..)
.map(|i| {
Ident::new(
&format!("__rpc_toolkit__rpc_handler__middleware_clone_{}", i),
Span::call_site(),
)
})
.take(args.middleware.len());
let middleware_name_clone2 = middleware_name_clone.clone();
let middleware_name_clone3 = middleware_name_clone.clone();
let middleware_name_clone4 = middleware_name_clone.clone();
let middleware_name_pre = (0..)
.map(|i| Ident::new(&format!("middleware_pre_{}", i), Span::call_site()))
.take(args.middleware.len());
let middleware_name_pre2 = middleware_name_pre.clone();
let middleware_name_post = (0..)
.map(|i| Ident::new(&format!("middleware_post_{}", i), Span::call_site()))
.take(args.middleware.len());
let middleware_name_post_inv = middleware_name_post
.clone()
.collect::<Vec<_>>()
.into_iter()
.rev();
let middleware_name = (0..)
.map(|i| Ident::new(&format!("middleware_{}", i), Span::call_site()))
.take(args.middleware.len());
let middleware_name2 = middleware_name.clone();
let middleware = args.middleware.iter();
let res = quote! {
{
let __rpc_toolkit__rpc_handler__context = #ctx;
let __rpc_toolkit__rpc_handler__parent_data = #parent_data;
let __rpc_toolkit__rpc_handler__status_fn = #status_fn;
#(
let #middleware_name_clone = ::std::sync::Arc::new(#middleware);
)*
let res: ::rpc_toolkit::RpcHandler = ::std::sync::Arc::new(move |mut req| {
let ctx = __rpc_toolkit__rpc_handler__context.clone();
let parent_data = __rpc_toolkit__rpc_handler__parent_data.clone();
let metadata = #command_module::Metadata::default();
#(
let #middleware_name_clone3 = #middleware_name_clone2.clone();
)*
::rpc_toolkit::futures::FutureExt::boxed(async move {
#(
let #middleware_name_pre = match ::rpc_toolkit::rpc_server_helpers::constrain_middleware(&*#middleware_name_clone4)(&mut req, metadata).await? {
Ok(a) => a,
Err(res) => return Ok(res),
};
)*
let (mut req_parts, req_body) = req.into_parts();
let (mut res_parts, _) = ::rpc_toolkit::hyper::Response::new(()).into_parts();
let rpc_req = ::rpc_toolkit::rpc_server_helpers::make_request(&req_parts, req_body).await;
match rpc_req {
Ok(mut rpc_req) => {
#(
let #middleware_name_post = match #middleware_name_pre2(&mut req_parts, &mut rpc_req).await? {
Ok(a) => a,
Err(res) => return Ok(res),
};
)*
let mut rpc_res = match ::rpc_toolkit::serde_json::from_value(::rpc_toolkit::serde_json::Value::Object(rpc_req.params)) {
Ok(params) => #command(ctx, parent_data, &req_parts, &mut res_parts, ::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method), params).await,
Err(e) => Err(e.into())
};
#(
let #middleware_name = match #middleware_name_post_inv(&mut res_parts, &mut rpc_res).await? {
Ok(a) => a,
Err(res) => return Ok(res),
};
)*
let mut res = ::rpc_toolkit::rpc_server_helpers::to_response(
&req_parts.headers,
res_parts,
Ok((
rpc_req.id,
rpc_res,
)),
__rpc_toolkit__rpc_handler__status_fn,
)?;
#(
#middleware_name2(&mut res).await?;
)*
Ok::<_, ::rpc_toolkit::hyper::http::Error>(res)
}
Err(e) => ::rpc_toolkit::rpc_server_helpers::to_response(
&req_parts.headers,
res_parts,
Err(e),
__rpc_toolkit__rpc_handler__status_fn,
),
}
})
});
res
}
};
// panic!("{}", res);
res
}

View File

@@ -1,12 +0,0 @@
use syn::*;
pub struct RpcHandlerArgs {
pub(crate) command: Path,
pub(crate) ctx: Expr,
pub(crate) parent_data: Option<Expr>,
pub(crate) status_fn: Option<Expr>,
pub(crate) middleware: punctuated::Punctuated<Expr, token::Comma>,
}
pub mod build;
mod parse;

View File

@@ -1,50 +0,0 @@
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use super::*;
impl Parse for RpcHandlerArgs {
fn parse(input: ParseStream) -> Result<Self> {
let args;
braced!(args in input);
let mut command = None;
let mut ctx = None;
let mut parent_data = None;
let mut status_fn = None;
let mut middleware = Punctuated::new();
while !args.is_empty() {
let arg_name: syn::Ident = args.parse()?;
let _: token::Colon = args.parse()?;
match arg_name.to_string().as_str() {
"command" => {
command = Some(args.parse()?);
}
"context" => {
ctx = Some(args.parse()?);
}
"parent_data" => {
parent_data = Some(args.parse()?);
}
"status" => {
status_fn = Some(args.parse()?);
}
"middleware" => {
let middlewares;
bracketed!(middlewares in args);
middleware = middlewares.parse_terminated(Expr::parse)?;
}
_ => return Err(Error::new(arg_name.span(), "unknown argument")),
}
if !args.is_empty() {
let _: token::Comma = args.parse()?;
}
}
Ok(RpcHandlerArgs {
command: command.expect("`command` is required"),
ctx: ctx.expect("`context` is required"),
parent_data,
status_fn,
middleware,
})
}
}

View File

@@ -1,25 +0,0 @@
use proc_macro2::TokenStream;
use quote::quote;
use super::*;
pub fn build(mut args: RpcServerArgs) -> TokenStream {
let ctx = std::mem::replace(
&mut args.ctx,
parse2(quote! { __rpc_toolkit__rpc_server__context }).unwrap(),
);
let handler = crate::rpc_handler::build::build(args);
let res = quote! {
{
let __rpc_toolkit__rpc_server__context = #ctx;
let __rpc_toolkit__rpc_server__builder = ::rpc_toolkit::rpc_server_helpers::make_builder(&__rpc_toolkit__rpc_server__context);
let handler = #handler;
__rpc_toolkit__rpc_server__builder.serve(::rpc_toolkit::hyper::service::make_service_fn(move |_| {
let handler = handler.clone();
async move { Ok::<_, ::std::convert::Infallible>(::rpc_toolkit::hyper::service::service_fn(move |req| handler(req))) }
}))
}
};
// panic!("{}", res);
res
}

View File

@@ -1,5 +0,0 @@
use syn::*;
pub type RpcServerArgs = crate::RpcHandlerArgs;
pub mod build;

View File

@@ -1,81 +0,0 @@
use proc_macro2::TokenStream;
use quote::quote;
use syn::spanned::Spanned;
use super::*;
pub fn build(args: RunCliArgs) -> TokenStream {
let mut command_handler = args.command.clone();
let mut arguments = std::mem::replace(
&mut command_handler.segments.last_mut().unwrap().arguments,
PathArguments::None,
);
let command = command_handler.clone();
if let PathArguments::AngleBracketed(a) = &mut arguments {
a.args.push(syn::parse2(quote! { () }).unwrap());
a.args.push(syn::parse2(quote! { _ }).unwrap());
}
command_handler.segments.push(PathSegment {
ident: Ident::new("cli_handler", command.span()),
arguments,
});
let app = if let Some(mut_app) = args.mut_app {
let ident = mut_app.app_ident;
let body = mut_app.body;
quote! {
{
let #ident = #command::build_app();
#body
}
}
} else {
quote! { #command::build_app() }
};
let make_ctx = if let Some(make_ctx) = args.make_ctx {
let ident = make_ctx.matches_ident;
let body = make_ctx.body;
quote! {
{
let #ident = &rpc_toolkit_matches;
#body
}
}
} else {
quote! { &rpc_toolkit_matches }
};
let parent_data = if let Some(data) = args.parent_data {
quote! { #data }
} else {
quote! { () }
};
let exit_fn = args.exit_fn.unwrap_or_else(|| {
syn::parse2(quote! { |err: ::rpc_toolkit::yajrc::RpcError| {
eprintln!("{}", err.message);
if let Some(data) = err.data {
eprintln!("{}", data);
}
std::process::exit(err.code);
} })
.unwrap()
});
quote! {
{
let rpc_toolkit_matches = #app.get_matches();
let rpc_toolkit_ctx = #make_ctx;
let rpc_toolkit_parent_data = #parent_data;
if let Err(err) = #command_handler(
rpc_toolkit_ctx,
rpc_toolkit_parent_data,
None,
&rpc_toolkit_matches,
"".into(),
(),
) {
drop(rpc_toolkit_matches);
(#exit_fn)(err);
} else {
drop(rpc_toolkit_matches);
}
}
}
}

View File

@@ -1,22 +0,0 @@
use syn::*;
pub struct MakeCtx {
matches_ident: Ident,
body: Expr,
}
pub struct MutApp {
app_ident: Ident,
body: Expr,
}
pub struct RunCliArgs {
command: Path,
mut_app: Option<MutApp>,
make_ctx: Option<MakeCtx>,
parent_data: Option<Expr>,
exit_fn: Option<Expr>,
}
pub mod build;
mod parse;

View File

@@ -1,68 +0,0 @@
use syn::parse::{Parse, ParseStream};
use super::*;
impl Parse for MakeCtx {
fn parse(input: ParseStream) -> Result<Self> {
let matches_ident = input.parse()?;
let _: token::FatArrow = input.parse()?;
let body = input.parse()?;
Ok(MakeCtx {
matches_ident,
body,
})
}
}
impl Parse for MutApp {
fn parse(input: ParseStream) -> Result<Self> {
let app_ident = input.parse()?;
let _: token::FatArrow = input.parse()?;
let body = input.parse()?;
Ok(MutApp { app_ident, body })
}
}
impl Parse for RunCliArgs {
fn parse(input: ParseStream) -> Result<Self> {
let args;
braced!(args in input);
let mut command = None;
let mut mut_app = None;
let mut make_ctx = None;
let mut parent_data = None;
let mut exit_fn = None;
while !args.is_empty() {
let arg_name: syn::Ident = args.parse()?;
let _: token::Colon = args.parse()?;
match arg_name.to_string().as_str() {
"command" => {
command = Some(args.parse()?);
}
"app" => {
mut_app = Some(args.parse()?);
}
"context" => {
make_ctx = Some(args.parse()?);
}
"parent_data" => {
parent_data = Some(args.parse()?);
}
"exit" => {
exit_fn = Some(args.parse()?);
}
_ => return Err(Error::new(arg_name.span(), "unknown argument")),
}
if !args.is_empty() {
let _: token::Comma = args.parse()?;
}
}
Ok(RunCliArgs {
command: command.expect("`command` is required"),
mut_app,
make_ctx,
parent_data,
exit_fn,
})
}
}

View File

@@ -42,21 +42,3 @@ pub fn context(_: TokenStream, _: TokenStream) -> TokenStream {
.to_compile_error()
.into()
}
#[proc_macro]
pub fn rpc_handler(item: TokenStream) -> TokenStream {
let item = syn::parse_macro_input!(item as RpcHandlerArgs);
build_rpc_handler(item).into()
}
#[proc_macro]
pub fn rpc_server(item: TokenStream) -> TokenStream {
let item = syn::parse_macro_input!(item as RpcServerArgs);
build_rpc_server(item).into()
}
#[proc_macro]
pub fn run_cli(item: TokenStream) -> TokenStream {
let item = syn::parse_macro_input!(item as RunCliArgs);
build_run_cli(item).into()
}

View File

@@ -27,6 +27,7 @@ imbl-value = "0.1"
lazy_format = "2"
lazy_static = "1.4"
openssl = { version = "0.10", features = ["vendored"] }
pin-project = "1"
reqwest = { version = "0.11" }
rpc-toolkit-macro = { version = "0.2.2", path = "../rpc-toolkit-macro" }
serde = { version = "1.0", features = ["derive"] }

View File

@@ -13,8 +13,8 @@ use yajrc::{Id, RpcError};
use crate::util::{internal_error, parse_error, Flat};
use crate::{
AnyHandler, CliBindingsAny, DynHandler, HandleAny, HandleAnyArgs, HandleArgs, Handler,
IntoContext, Name, ParentHandler,
AnyHandler, CliBindings, CliBindingsAny, DynHandler, HandleAny, HandleAnyArgs, HandleArgs,
Handler, HandlerTypes, IntoContext, Name, ParentHandler,
};
type GenericRpcMethod<'a> = yajrc::GenericRpcMethod<&'a str, Value, Value>;
@@ -81,120 +81,123 @@ impl<Context: crate::Context + Clone, Config: CommandFactory + FromArgMatches>
}
#[async_trait::async_trait]
pub trait CliContext: crate::Context {
pub trait CallRemote: crate::Context {
async fn call_remote(&self, method: &str, params: Value) -> Result<Value, RpcError>;
}
#[async_trait::async_trait]
pub trait CliContextHttp: crate::Context {
fn client(&self) -> &Client;
fn url(&self) -> Url;
async fn call_remote(&self, method: &str, params: Value) -> Result<Value, RpcError> {
let rpc_req = RpcRequest {
id: Some(Id::Number(0.into())),
method: GenericRpcMethod::new(method),
params,
};
let mut req = self.client().request(Method::POST, self.url());
let body;
pub async fn call_remote_http(
client: &Client,
url: Url,
method: &str,
params: Value,
) -> Result<Value, RpcError> {
let rpc_req = RpcRequest {
id: Some(Id::Number(0.into())),
method: GenericRpcMethod::new(method),
params,
};
let mut req = client.request(Method::POST, url);
let body;
#[cfg(feature = "cbor")]
{
req = req.header("content-type", "application/cbor");
req = req.header("accept", "application/cbor, application/json");
body = serde_cbor::to_vec(&rpc_req)?;
}
#[cfg(not(feature = "cbor"))]
{
req = req.header("content-type", "application/json");
req = req.header("accept", "application/json");
body = serde_json::to_vec(&req)?;
}
let res = req
.header("content-length", body.len())
.body(body)
.send()
.await?;
match res
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
{
Some("application/json") => {
serde_json::from_slice::<RpcResponse>(&*res.bytes().await.map_err(internal_error)?)
.map_err(parse_error)?
.result
}
#[cfg(feature = "cbor")]
{
req = req.header("content-type", "application/cbor");
req = req.header("accept", "application/cbor, application/json");
body = serde_cbor::to_vec(&rpc_req)?;
Some("application/cbor") => {
serde_cbor::from_slice::<RpcResponse>(&*res.bytes().await.map_err(internal_error)?)
.map_err(parse_error)?
.result
}
#[cfg(not(feature = "cbor"))]
{
req = req.header("content-type", "application/json");
req = req.header("accept", "application/json");
body = serde_json::to_vec(&req)?;
}
let res = req
.header("content-length", body.len())
.body(body)
.send()
.await?;
_ => Err(internal_error("missing content type")),
}
}
match res
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
{
Some("application/json") => {
serde_json::from_slice::<RpcResponse>(&*res.bytes().await.map_err(internal_error)?)
.map_err(parse_error)?
.result
}
#[cfg(feature = "cbor")]
Some("application/cbor") => {
serde_cbor::from_slice::<RpcResponse>(&*res.bytes().await.map_err(internal_error)?)
.map_err(parse_error)?
.result
}
_ => Err(internal_error("missing content type")),
pub async fn call_remote_socket(
connection: impl AsyncRead + AsyncWrite,
method: &str,
params: Value,
) -> Result<Value, RpcError> {
let rpc_req = RpcRequest {
id: Some(Id::Number(0.into())),
method: GenericRpcMethod::new(method),
params,
};
let conn = connection;
tokio::pin!(conn);
let mut buf = serde_json::to_vec(&rpc_req).map_err(|e| RpcError {
data: Some(e.to_string().into()),
..yajrc::INTERNAL_ERROR
})?;
buf.push(b'\n');
conn.write_all(&buf).await.map_err(|e| RpcError {
data: Some(e.to_string().into()),
..yajrc::INTERNAL_ERROR
})?;
let mut line = String::new();
BufReader::new(conn).read_line(&mut line).await?;
serde_json::from_str::<RpcResponse>(&line)
.map_err(parse_error)?
.result
}
pub struct CallRemoteHandler<RemoteContext, RemoteHandler> {
_phantom: PhantomData<RemoteContext>,
handler: RemoteHandler,
}
impl<RemoteContext, RemoteHandler> CallRemoteHandler<RemoteContext, RemoteHandler> {
pub fn new(handler: RemoteHandler) -> Self {
Self {
_phantom: PhantomData,
handler: handler,
}
}
}
#[async_trait::async_trait]
impl<T> CliContext for T
where
T: CliContextHttp,
impl<RemoteContext, RemoteHandler: Clone> Clone
for CallRemoteHandler<RemoteContext, RemoteHandler>
{
async fn call_remote(&self, method: &str, params: Value) -> Result<Value, RpcError> {
<Self as CliContextHttp>::call_remote(&self, method, params).await
}
}
#[async_trait::async_trait]
pub trait CliContextSocket: crate::Context {
type Stream: AsyncRead + AsyncWrite + Send;
async fn connect(&self) -> std::io::Result<Self::Stream>;
async fn call_remote(&self, method: &str, params: Value) -> Result<Value, RpcError> {
let rpc_req = RpcRequest {
id: Some(Id::Number(0.into())),
method: GenericRpcMethod::new(method),
params,
};
let conn = self.connect().await.map_err(|e| RpcError {
data: Some(e.to_string().into()),
..yajrc::INTERNAL_ERROR
})?;
tokio::pin!(conn);
let mut buf = serde_json::to_vec(&rpc_req).map_err(|e| RpcError {
data: Some(e.to_string().into()),
..yajrc::INTERNAL_ERROR
})?;
buf.push(b'\n');
conn.write_all(&buf).await.map_err(|e| RpcError {
data: Some(e.to_string().into()),
..yajrc::INTERNAL_ERROR
})?;
let mut line = String::new();
BufReader::new(conn).read_line(&mut line).await?;
serde_json::from_str::<RpcResponse>(&line)
.map_err(parse_error)?
.result
}
}
#[derive(Debug, Default)]
pub struct CallRemote<RemoteContext, RemoteHandler>(PhantomData<(RemoteContext, RemoteHandler)>);
impl<RemoteContext, RemoteHandler> CallRemote<RemoteContext, RemoteHandler> {
pub fn new() -> Self {
Self(PhantomData)
}
}
impl<RemoteContext, RemoteHandler> Clone for CallRemote<RemoteContext, RemoteHandler> {
fn clone(&self) -> Self {
Self(PhantomData)
Self {
_phantom: PhantomData,
handler: self.handler.clone(),
}
}
}
#[async_trait::async_trait]
impl<Context: CliContext, RemoteContext, RemoteHandler> Handler<Context>
for CallRemote<RemoteContext, RemoteHandler>
impl<RemoteContext, RemoteHandler> std::fmt::Debug
for CallRemoteHandler<RemoteContext, RemoteHandler>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("CallRemoteHandler").finish()
}
}
impl<RemoteContext, RemoteHandler> HandlerTypes for CallRemoteHandler<RemoteContext, RemoteHandler>
where
RemoteContext: IntoContext,
RemoteHandler: Handler<RemoteContext>,
RemoteHandler: HandlerTypes,
RemoteHandler::Params: Serialize,
RemoteHandler::InheritedParams: Serialize,
RemoteHandler::Ok: DeserializeOwned,
@@ -204,6 +207,18 @@ where
type InheritedParams = RemoteHandler::InheritedParams;
type Ok = RemoteHandler::Ok;
type Err = RemoteHandler::Err;
}
#[async_trait::async_trait]
impl<Context: CallRemote, RemoteContext, RemoteHandler> Handler<Context>
for CallRemoteHandler<RemoteContext, RemoteHandler>
where
RemoteContext: IntoContext,
RemoteHandler: Handler<RemoteContext>,
RemoteHandler::Params: Serialize,
RemoteHandler::InheritedParams: Serialize,
RemoteHandler::Ok: DeserializeOwned,
RemoteHandler::Err: From<RpcError>,
{
async fn handle_async(
&self,
handle_args: HandleArgs<Context, Self>,
@@ -229,3 +244,49 @@ where
}
}
}
// #[async_trait::async_trait]
impl<Context: CallRemote, RemoteContext, RemoteHandler> CliBindings<Context>
for CallRemoteHandler<RemoteContext, RemoteHandler>
where
RemoteContext: IntoContext,
RemoteHandler: Handler<RemoteContext> + CliBindings<Context>,
RemoteHandler::Params: Serialize,
RemoteHandler::InheritedParams: Serialize,
RemoteHandler::Ok: DeserializeOwned,
RemoteHandler::Err: From<RpcError>,
{
fn cli_command(&self, ctx_ty: TypeId) -> clap::Command {
self.handler.cli_command(ctx_ty)
}
fn cli_parse(
&self,
matches: &clap::ArgMatches,
ctx_ty: TypeId,
) -> Result<(std::collections::VecDeque<&'static str>, Value), clap::Error> {
self.handler.cli_parse(matches, ctx_ty)
}
fn cli_display(
&self,
HandleArgs {
context,
parent_method,
method,
params,
inherited_params,
raw_params,
}: HandleArgs<Context, Self>,
result: Self::Ok,
) -> Result<(), Self::Err> {
self.handler.cli_display(
HandleArgs {
context,
parent_method,
method,
params,
inherited_params,
raw_params,
},
result,
)
}
}

View File

@@ -0,0 +1,33 @@
use std::fmt::Display;
use std::io::Stdin;
use std::str::FromStr;
use clap::ArgMatches;
pub use {clap, serde};
pub fn default_arg_parser<T>(arg: &str, _: &ArgMatches) -> Result<T, clap::Error>
where
T: FromStr,
T::Err: Display,
{
arg.parse()
.map_err(|e| clap::Error::raw(clap::error::ErrorKind::ValueValidation, e))
}
pub fn default_stdin_parser<T>(stdin: &mut Stdin, _: &ArgMatches) -> Result<T, clap::Error>
where
T: FromStr,
T::Err: Display,
{
let mut s = String::new();
stdin
.read_line(&mut s)
.map_err(|e| clap::Error::raw(clap::error::ErrorKind::Io, e))?;
if let Some(s) = s.strip_suffix("\n") {
s
} else {
&s
}
.parse()
.map_err(|e| clap::Error::raw(clap::error::ErrorKind::ValueValidation, e))
}

View File

@@ -6,6 +6,9 @@ use tokio::runtime::Handle;
use crate::Handler;
pub trait Context: Any + Send + Sync + 'static {
fn inner_type_id(&self) -> TypeId {
<Self as Any>::type_id(&self)
}
fn runtime(&self) -> Handle {
Handle::current()
}
@@ -36,7 +39,7 @@ impl<C: Context + Sized> IntoContext for C {
AnyContext::new(self)
}
fn downcast(value: AnyContext) -> Result<Self, AnyContext> {
if value.0.type_id() == TypeId::of::<C>() {
if value.0.inner_type_id() == TypeId::of::<C>() {
unsafe { Ok(value.downcast_unchecked::<C>()) }
} else {
Err(value)
@@ -90,10 +93,8 @@ impl AnyContext {
Self(Box::new(value))
}
unsafe fn downcast_unchecked<C: Context>(self) -> C {
unsafe {
let raw: *mut dyn Context = Box::into_raw(self.0);
*Box::from_raw(raw as *mut C)
}
let raw: *mut dyn Context = Box::into_raw(self.0);
*Box::from_raw(raw as *mut C)
}
}
@@ -105,7 +106,7 @@ impl IntoContext for AnyContext {
None
}
fn inner_type_id(&self) -> TypeId {
self.0.type_id()
self.0.inner_type_id()
}
fn downcast(value: AnyContext) -> Result<Self, AnyContext> {
Ok(value)

View File

@@ -14,6 +14,7 @@ use yajrc::RpcError;
use crate::context::{AnyContext, IntoContext};
use crate::util::{combine, internal_error, invalid_params, Flat};
use crate::{CallRemote, CallRemoteHandler};
pub(crate) struct HandleAnyArgs {
pub(crate) context: AnyContext,
@@ -24,7 +25,7 @@ pub(crate) struct HandleAnyArgs {
impl HandleAnyArgs {
fn downcast<Context: IntoContext, H>(self) -> Result<HandleArgs<Context, H>, imbl_value::Error>
where
H: Handler<Context>,
H: HandlerTypes,
H::Params: DeserializeOwned,
H::InheritedParams: DeserializeOwned,
{
@@ -49,7 +50,7 @@ impl HandleAnyArgs {
}
#[async_trait::async_trait]
pub(crate) trait HandleAny: Send + Sync {
pub(crate) trait HandleAny: std::fmt::Debug + Send + Sync {
fn handle_sync(&self, handle_args: HandleAnyArgs) -> Result<Value, RpcError>;
async fn handle_async(&self, handle_args: HandleAnyArgs) -> Result<Value, RpcError>;
fn method_from_dots(&self, method: &str, ctx_ty: TypeId) -> Option<VecDeque<&'static str>>;
@@ -77,7 +78,7 @@ pub(crate) trait CliBindingsAny {
fn cli_display(&self, handle_args: HandleAnyArgs, result: Value) -> Result<(), RpcError>;
}
pub trait CliBindings<Context: IntoContext>: Handler<Context> {
pub trait CliBindings<Context: IntoContext>: HandlerTypes {
fn cli_command(&self, ctx_ty: TypeId) -> Command;
fn cli_parse(
&self,
@@ -91,7 +92,7 @@ pub trait CliBindings<Context: IntoContext>: Handler<Context> {
) -> Result<(), Self::Err>;
}
pub trait PrintCliResult<Context: IntoContext>: Handler<Context> {
pub trait PrintCliResult<Context: IntoContext>: HandlerTypes {
fn print(
&self,
handle_args: HandleArgs<Context, Self>,
@@ -102,7 +103,7 @@ pub trait PrintCliResult<Context: IntoContext>: Handler<Context> {
pub(crate) trait HandleAnyWithCli: HandleAny + CliBindingsAny {}
impl<T: HandleAny + CliBindingsAny> HandleAnyWithCli for T {}
#[derive(Clone)]
#[derive(Debug, Clone)]
pub(crate) enum DynHandler {
WithoutCli(Arc<dyn HandleAny>),
WithCli(Arc<dyn HandleAnyWithCli>),
@@ -130,7 +131,7 @@ impl HandleAny for DynHandler {
}
#[derive(Debug, Clone)]
pub struct HandleArgs<Context: IntoContext, H: Handler<Context> + ?Sized> {
pub struct HandleArgs<Context: IntoContext, H: HandlerTypes + ?Sized> {
pub context: Context,
pub parent_method: Vec<&'static str>,
pub method: VecDeque<&'static str>,
@@ -139,12 +140,17 @@ pub struct HandleArgs<Context: IntoContext, H: Handler<Context> + ?Sized> {
pub raw_params: Value,
}
#[async_trait::async_trait]
pub trait Handler<Context: IntoContext>: Clone + Send + Sync + 'static {
pub trait HandlerTypes {
type Params: Send + Sync;
type InheritedParams: Send + Sync;
type Ok: Send + Sync;
type Err: Send + Sync;
}
#[async_trait::async_trait]
pub trait Handler<Context: IntoContext>:
HandlerTypes + std::fmt::Debug + Clone + Send + Sync + 'static
{
fn handle_sync(&self, handle_args: HandleArgs<Context, Self>) -> Result<Self::Ok, Self::Err> {
handle_args
.context
@@ -198,6 +204,11 @@ impl<Context, H> AnyHandler<Context, H> {
}
}
}
impl<Context, H: std::fmt::Debug> std::fmt::Debug for AnyHandler<Context, H> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("AnyHandler").field(&self.handler).finish()
}
}
#[async_trait::async_trait]
impl<Context: IntoContext, H: Handler<Context>> HandleAny for AnyHandler<Context, H>
@@ -271,7 +282,7 @@ impl<'a> std::borrow::Borrow<Option<&'a str>> for Name {
}
}
#[derive(Clone)]
#[derive(Debug, Clone)]
pub(crate) struct SubcommandMap(pub(crate) BTreeMap<Name, BTreeMap<Option<TypeId>, DynHandler>>);
impl SubcommandMap {
fn insert(
@@ -324,6 +335,13 @@ impl<Params, InheritedParams> Clone for ParentHandler<Params, InheritedParams> {
}
}
}
impl<Params, InheritedParams> std::fmt::Debug for ParentHandler<Params, InheritedParams> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("ParentHandler")
.field(&self.subcommands)
.finish()
}
}
struct InheritanceHandler<Context, Params, InheritedParams, H, F> {
_phantom: PhantomData<(Context, Params, InheritedParams)>,
@@ -341,6 +359,28 @@ impl<Context, Params, InheritedParams, H: Clone, F: Clone> Clone
}
}
}
impl<Context, Params, InheritedParams, H: std::fmt::Debug, F> std::fmt::Debug
for InheritanceHandler<Context, Params, InheritedParams, H, F>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("InheritanceHandler")
.field(&self.handler)
.finish()
}
}
impl<Context, Params, InheritedParams, H, F> HandlerTypes
for InheritanceHandler<Context, Params, InheritedParams, H, F>
where
Context: IntoContext,
H: HandlerTypes,
Params: Send + Sync,
InheritedParams: Send + Sync,
{
type Params = H::Params;
type InheritedParams = Flat<Params, InheritedParams>;
type Ok = H::Ok;
type Err = H::Err;
}
#[async_trait::async_trait]
impl<Context, Params, InheritedParams, H, F> Handler<Context>
for InheritanceHandler<Context, Params, InheritedParams, H, F>
@@ -351,10 +391,6 @@ where
H: Handler<Context>,
F: Fn(Params, InheritedParams) -> H::InheritedParams + Send + Sync + Clone + 'static,
{
type Params = H::Params;
type InheritedParams = Flat<Params, InheritedParams>;
type Ok = H::Ok;
type Err = H::Err;
fn handle_sync(
&self,
HandleArgs {
@@ -446,7 +482,10 @@ impl<Params: Send + Sync, InheritedParams: Send + Sync> ParentHandler<Params, In
pub fn subcommand<Context, H>(mut self, name: &'static str, handler: H) -> Self
where
Context: IntoContext,
H: CliBindings<Context, InheritedParams = NoParams> + 'static,
H: HandlerTypes<InheritedParams = NoParams>
+ Handler<Context>
+ CliBindings<Context>
+ 'static,
H::Params: DeserializeOwned,
H::Ok: Serialize + DeserializeOwned,
RpcError: From<H::Err>,
@@ -454,10 +493,42 @@ impl<Params: Send + Sync, InheritedParams: Send + Sync> ParentHandler<Params, In
self.subcommands.insert(
handler.contexts(),
name.into(),
DynHandler::WithCli(Arc::new(AnyHandler {
_ctx: PhantomData,
handler,
})),
DynHandler::WithCli(Arc::new(AnyHandler::new(handler))),
);
self
}
pub fn subcommand_remote_cli<CliContext, ServerContext, H>(
mut self,
name: &'static str,
handler: H,
) -> Self
where
ServerContext: IntoContext,
CliContext: IntoContext + CallRemote,
H: HandlerTypes<InheritedParams = NoParams>
+ Handler<ServerContext>
+ CliBindings<CliContext>
+ 'static,
H::Params: Serialize + DeserializeOwned,
H::Ok: Serialize + DeserializeOwned,
RpcError: From<H::Err>,
H::Err: From<RpcError>,
CallRemoteHandler<ServerContext, H>: Handler<CliContext>,
<CallRemoteHandler<ServerContext, H> as HandlerTypes>::Ok: Serialize + DeserializeOwned,
<CallRemoteHandler<ServerContext, H> as HandlerTypes>::Params: Serialize + DeserializeOwned,
<CallRemoteHandler<ServerContext, H> as HandlerTypes>::InheritedParams: DeserializeOwned,
RpcError: From<<CallRemoteHandler<ServerContext, H> as HandlerTypes>::Err>,
{
self.subcommands.insert(
handler.contexts(),
name.into(),
DynHandler::WithoutCli(Arc::new(AnyHandler::new(handler.clone()))),
);
let call_remote = CallRemoteHandler::<ServerContext, H>::new(handler);
self.subcommands.insert(
call_remote.contexts(),
name.into(),
DynHandler::WithCli(Arc::new(AnyHandler::new(call_remote))),
);
self
}
@@ -472,10 +543,7 @@ impl<Params: Send + Sync, InheritedParams: Send + Sync> ParentHandler<Params, In
self.subcommands.insert(
handler.contexts(),
name.into(),
DynHandler::WithoutCli(Arc::new(AnyHandler {
_ctx: PhantomData,
handler,
})),
DynHandler::WithoutCli(Arc::new(AnyHandler::new(handler))),
);
self
}
@@ -493,7 +561,7 @@ where
) -> Self
where
Context: IntoContext,
H: CliBindings<Context> + 'static,
H: Handler<Context> + CliBindings<Context> + 'static,
H::Params: DeserializeOwned,
H::Ok: Serialize + DeserializeOwned,
RpcError: From<H::Err>,
@@ -513,6 +581,62 @@ where
);
self
}
pub fn subcommand_with_inherited_remote_cli<CliContext, ServerContext, H, F>(
mut self,
name: &'static str,
handler: H,
inherit: F,
) -> Self
where
ServerContext: IntoContext,
CliContext: IntoContext + CallRemote,
H: HandlerTypes + Handler<ServerContext> + CliBindings<CliContext> + 'static,
H::Params: Serialize + DeserializeOwned,
H::Ok: Serialize + DeserializeOwned,
RpcError: From<H::Err>,
H::Err: From<RpcError>,
CallRemoteHandler<ServerContext, H>:
Handler<CliContext, InheritedParams = H::InheritedParams>,
<CallRemoteHandler<ServerContext, H> as HandlerTypes>::Ok: Serialize + DeserializeOwned,
<CallRemoteHandler<ServerContext, H> as HandlerTypes>::Params: Serialize + DeserializeOwned,
<CallRemoteHandler<ServerContext, H> as HandlerTypes>::InheritedParams:
Serialize + DeserializeOwned,
RpcError: From<<CallRemoteHandler<ServerContext, H> as HandlerTypes>::Err>,
F: Fn(Params, InheritedParams) -> H::InheritedParams + Send + Sync + Clone + 'static,
{
self.subcommands.insert(
handler.contexts(),
name.into(),
DynHandler::WithoutCli(Arc::new(AnyHandler::new(InheritanceHandler::<
ServerContext,
Params,
InheritedParams,
H,
F,
> {
_phantom: PhantomData,
handler: handler.clone(),
inherit: inherit.clone(),
}))),
);
let call_remote = CallRemoteHandler::<ServerContext, H>::new(handler);
self.subcommands.insert(
call_remote.contexts(),
name.into(),
DynHandler::WithCli(Arc::new(AnyHandler::new(InheritanceHandler::<
CliContext,
Params,
InheritedParams,
CallRemoteHandler<ServerContext, H>,
F,
> {
_phantom: PhantomData,
handler: call_remote,
inherit,
}))),
);
self
}
pub fn subcommand_with_inherited_no_cli<Context, H, F>(
mut self,
name: &'static str,
@@ -544,7 +668,7 @@ where
pub fn root_handler<Context, H, F>(mut self, handler: H, inherit: F) -> Self
where
Context: IntoContext,
H: CliBindings<Context, Params = NoParams> + 'static,
H: HandlerTypes<Params = NoParams> + Handler<Context> + CliBindings<Context> + 'static,
H::Params: DeserializeOwned,
H::Ok: Serialize + DeserializeOwned,
RpcError: From<H::Err>,
@@ -589,17 +713,23 @@ where
}
}
#[async_trait::async_trait]
impl<
Context: IntoContext,
Params: Serialize + Send + Sync + 'static,
InheritedParams: Serialize + Send + Sync + 'static,
> Handler<Context> for ParentHandler<Params, InheritedParams>
impl<Params, InheritedParams> HandlerTypes for ParentHandler<Params, InheritedParams>
where
Params: Send + Sync,
InheritedParams: Send + Sync,
{
type Params = Params;
type InheritedParams = InheritedParams;
type Ok = Value;
type Err = RpcError;
}
#[async_trait::async_trait]
impl<Context, Params, InheritedParams> Handler<Context> for ParentHandler<Params, InheritedParams>
where
Context: IntoContext,
Params: Serialize + Send + Sync + 'static,
InheritedParams: Serialize + Send + Sync + 'static,
{
fn handle_sync(
&self,
HandleArgs {
@@ -614,7 +744,7 @@ impl<
if let Some(cmd) = cmd {
parent_method.push(cmd);
}
if let Some((_, sub_handler)) = self.subcommands.get(context.inner_type_id(), cmd) {
if let Some((_, sub_handler)) = &self.subcommands.get(context.inner_type_id(), cmd) {
sub_handler.handle_sync(HandleAnyArgs {
context: context.upcast(),
parent_method,
@@ -776,11 +906,18 @@ impl<F: Clone, T, E, Args> Clone for FromFn<F, T, E, Args> {
}
}
}
impl<F, T, E, Args> std::fmt::Debug for FromFn<F, T, E, Args> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FromFn")
.field("blocking", &self.blocking)
.finish()
}
}
impl<Context, F, T, E, Args> PrintCliResult<Context> for FromFn<F, T, E, Args>
where
Context: IntoContext,
Self: Handler<Context>,
<Self as Handler<Context>>::Ok: Display,
Self: HandlerTypes,
<Self as HandlerTypes>::Ok: Display,
{
fn print(&self, _: HandleArgs<Context, Self>, result: Self::Ok) -> Result<(), Self::Err> {
Ok(println!("{result}"))
@@ -815,11 +952,16 @@ impl<F: Clone, Fut, T, E, Args> Clone for FromFnAsync<F, Fut, T, E, Args> {
}
}
}
impl<F, Fut, T, E, Args> std::fmt::Debug for FromFnAsync<F, Fut, T, E, Args> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FromFnAsync").finish()
}
}
impl<Context, F, Fut, T, E, Args> PrintCliResult<Context> for FromFnAsync<F, Fut, T, E, Args>
where
Context: IntoContext,
Self: Handler<Context>,
<Self as Handler<Context>>::Ok: Display,
Self: HandlerTypes,
<Self as HandlerTypes>::Ok: Display,
{
fn print(&self, _: HandleArgs<Context, Self>, result: Self::Ok) -> Result<(), Self::Err> {
Ok(println!("{result}"))
@@ -833,10 +975,8 @@ pub fn from_fn_async<F, Fut, T, E, Args>(function: F) -> FromFnAsync<F, Fut, T,
}
}
#[async_trait::async_trait]
impl<Context, F, T, E> Handler<Context> for FromFn<F, T, E, ()>
impl<F, T, E> HandlerTypes for FromFn<F, T, E, ()>
where
Context: IntoContext,
F: Fn() -> Result<T, E> + Send + Sync + Clone + 'static,
T: Send + Sync + 'static,
E: Send + Sync + 'static,
@@ -845,6 +985,15 @@ where
type InheritedParams = NoParams;
type Ok = T;
type Err = E;
}
#[async_trait::async_trait]
impl<Context, F, T, E> Handler<Context> for FromFn<F, T, E, ()>
where
Context: IntoContext,
F: Fn() -> Result<T, E> + Send + Sync + Clone + 'static,
T: Send + Sync + 'static,
E: Send + Sync + 'static,
{
fn handle_sync(&self, _: HandleArgs<Context, Self>) -> Result<Self::Ok, Self::Err> {
(self.function)()
}
@@ -859,10 +1008,8 @@ where
}
}
}
#[async_trait::async_trait]
impl<Context, F, Fut, T, E> Handler<Context> for FromFnAsync<F, Fut, T, E, ()>
impl<F, Fut, T, E> HandlerTypes for FromFnAsync<F, Fut, T, E, ()>
where
Context: IntoContext,
F: Fn() -> Fut + Send + Sync + Clone + 'static,
Fut: Future<Output = Result<T, E>> + Send + Sync + 'static,
T: Send + Sync + 'static,
@@ -872,13 +1019,22 @@ where
type InheritedParams = NoParams;
type Ok = T;
type Err = E;
}
#[async_trait::async_trait]
impl<Context, F, Fut, T, E> Handler<Context> for FromFnAsync<F, Fut, T, E, ()>
where
Context: IntoContext,
F: Fn() -> Fut + Send + Sync + Clone + 'static,
Fut: Future<Output = Result<T, E>> + Send + Sync + 'static,
T: Send + Sync + 'static,
E: Send + Sync + 'static,
{
async fn handle_async(&self, _: HandleArgs<Context, Self>) -> Result<Self::Ok, Self::Err> {
(self.function)().await
}
}
#[async_trait::async_trait]
impl<Context, F, T, E> Handler<Context> for FromFn<F, T, E, (Context,)>
impl<Context, F, T, E> HandlerTypes for FromFn<F, T, E, (Context,)>
where
Context: IntoContext,
F: Fn(Context) -> Result<T, E> + Send + Sync + Clone + 'static,
@@ -889,6 +1045,15 @@ where
type InheritedParams = NoParams;
type Ok = T;
type Err = E;
}
#[async_trait::async_trait]
impl<Context, F, T, E> Handler<Context> for FromFn<F, T, E, (Context,)>
where
Context: IntoContext,
F: Fn(Context) -> Result<T, E> + Send + Sync + Clone + 'static,
T: Send + Sync + 'static,
E: Send + Sync + 'static,
{
fn handle_sync(&self, handle_args: HandleArgs<Context, Self>) -> Result<Self::Ok, Self::Err> {
(self.function)(handle_args.context)
}
@@ -903,8 +1068,7 @@ where
}
}
}
#[async_trait::async_trait]
impl<Context, F, Fut, T, E> Handler<Context> for FromFnAsync<F, Fut, T, E, (Context,)>
impl<Context, F, Fut, T, E> HandlerTypes for FromFnAsync<F, Fut, T, E, (Context,)>
where
Context: IntoContext,
F: Fn(Context) -> Fut + Send + Sync + Clone + 'static,
@@ -916,6 +1080,16 @@ where
type InheritedParams = NoParams;
type Ok = T;
type Err = E;
}
#[async_trait::async_trait]
impl<Context, F, Fut, T, E> Handler<Context> for FromFnAsync<F, Fut, T, E, (Context,)>
where
Context: IntoContext,
F: Fn(Context) -> Fut + Send + Sync + Clone + 'static,
Fut: Future<Output = Result<T, E>> + Send + Sync + 'static,
T: Send + Sync + 'static,
E: Send + Sync + 'static,
{
async fn handle_async(
&self,
handle_args: HandleArgs<Context, Self>,
@@ -924,8 +1098,7 @@ where
}
}
#[async_trait::async_trait]
impl<Context, F, T, E, Params> Handler<Context> for FromFn<F, T, E, (Context, Params)>
impl<Context, F, T, E, Params> HandlerTypes for FromFn<F, T, E, (Context, Params)>
where
Context: IntoContext,
F: Fn(Context, Params) -> Result<T, E> + Send + Sync + Clone + 'static,
@@ -937,6 +1110,16 @@ where
type InheritedParams = NoParams;
type Ok = T;
type Err = E;
}
#[async_trait::async_trait]
impl<Context, F, T, E, Params> Handler<Context> for FromFn<F, T, E, (Context, Params)>
where
Context: IntoContext,
F: Fn(Context, Params) -> Result<T, E> + Send + Sync + Clone + 'static,
Params: DeserializeOwned + Send + Sync + 'static,
T: Send + Sync + 'static,
E: Send + Sync + 'static,
{
fn handle_sync(&self, handle_args: HandleArgs<Context, Self>) -> Result<Self::Ok, Self::Err> {
let HandleArgs {
context, params, ..
@@ -954,9 +1137,8 @@ where
}
}
}
#[async_trait::async_trait]
impl<Context, F, Fut, T, E, Params> Handler<Context>
for FromFnAsync<F, Fut, T, E, (Context, Params)>
impl<Context, F, Fut, T, E, Params> HandlerTypes for FromFnAsync<F, Fut, T, E, (Context, Params)>
where
Context: IntoContext,
F: Fn(Context, Params) -> Fut + Send + Sync + Clone + 'static,
@@ -969,6 +1151,18 @@ where
type InheritedParams = NoParams;
type Ok = T;
type Err = E;
}
#[async_trait::async_trait]
impl<Context, F, Fut, T, E, Params> Handler<Context>
for FromFnAsync<F, Fut, T, E, (Context, Params)>
where
Context: IntoContext,
F: Fn(Context, Params) -> Fut + Send + Sync + Clone + 'static,
Fut: Future<Output = Result<T, E>> + Send + Sync + 'static,
Params: DeserializeOwned + Send + Sync + 'static,
T: Send + Sync + 'static,
E: Send + Sync + 'static,
{
async fn handle_async(
&self,
handle_args: HandleArgs<Context, Self>,
@@ -980,8 +1174,7 @@ where
}
}
#[async_trait::async_trait]
impl<Context, F, T, E, Params, InheritedParams> Handler<Context>
impl<Context, F, T, E, Params, InheritedParams> HandlerTypes
for FromFn<F, T, E, (Context, Params, InheritedParams)>
where
Context: IntoContext,
@@ -995,6 +1188,18 @@ where
type InheritedParams = InheritedParams;
type Ok = T;
type Err = E;
}
#[async_trait::async_trait]
impl<Context, F, T, E, Params, InheritedParams> Handler<Context>
for FromFn<F, T, E, (Context, Params, InheritedParams)>
where
Context: IntoContext,
F: Fn(Context, Params, InheritedParams) -> Result<T, E> + Send + Sync + Clone + 'static,
Params: DeserializeOwned + Send + Sync + 'static,
InheritedParams: DeserializeOwned + Send + Sync + 'static,
T: Send + Sync + 'static,
E: Send + Sync + 'static,
{
fn handle_sync(&self, handle_args: HandleArgs<Context, Self>) -> Result<Self::Ok, Self::Err> {
let HandleArgs {
context,
@@ -1015,8 +1220,8 @@ where
}
}
}
#[async_trait::async_trait]
impl<Context, F, Fut, T, E, Params, InheritedParams> Handler<Context>
impl<Context, F, Fut, T, E, Params, InheritedParams> HandlerTypes
for FromFnAsync<F, Fut, T, E, (Context, Params, InheritedParams)>
where
Context: IntoContext,
@@ -1031,6 +1236,19 @@ where
type InheritedParams = InheritedParams;
type Ok = T;
type Err = E;
}
#[async_trait::async_trait]
impl<Context, F, Fut, T, E, Params, InheritedParams> Handler<Context>
for FromFnAsync<F, Fut, T, E, (Context, Params, InheritedParams)>
where
Context: IntoContext,
F: Fn(Context, Params, InheritedParams) -> Fut + Send + Sync + Clone + 'static,
Fut: Future<Output = Result<T, E>> + Send + Sync + 'static,
Params: DeserializeOwned + Send + Sync + 'static,
InheritedParams: DeserializeOwned + Send + Sync + 'static,
T: Send + Sync + 'static,
E: Send + Sync + 'static,
{
async fn handle_async(
&self,
handle_args: HandleArgs<Context, Self>,
@@ -1048,7 +1266,7 @@ where
impl<Context, F, T, E, Args> CliBindings<Context> for FromFn<F, T, E, Args>
where
Context: IntoContext,
Self: Handler<Context>,
Self: HandlerTypes,
Self::Params: FromArgMatches + CommandFactory + Serialize,
Self: PrintCliResult<Context>,
{
@@ -1097,7 +1315,7 @@ where
impl<Context, F, Fut, T, E, Args> CliBindings<Context> for FromFnAsync<F, Fut, T, E, Args>
where
Context: IntoContext,
Self: Handler<Context>,
Self: HandlerTypes,
Self::Params: FromArgMatches + CommandFactory + Serialize,
Self: PrintCliResult<Context>,
{

View File

@@ -28,6 +28,7 @@ pub use server::*;
pub use {clap, futures, hyper, reqwest, serde, serde_json, tokio, url, yajrc};
mod cli;
pub mod command_helpers;
mod context;
mod handler;
mod server;

View File

@@ -5,17 +5,17 @@ use futures::{Future, Stream, StreamExt, TryStreamExt};
use imbl_value::Value;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::net::{TcpListener, ToSocketAddrs, UnixListener};
use tokio::sync::OnceCell;
use tokio::sync::Notify;
use yajrc::RpcError;
use crate::util::{parse_error, JobRunner};
use crate::util::{parse_error, JobRunner, StreamUntil};
use crate::Server;
#[derive(Clone)]
pub struct ShutdownHandle(Arc<OnceCell<()>>);
pub struct ShutdownHandle(Arc<Notify>);
impl ShutdownHandle {
pub fn shutdown(self) {
let _ = self.0.set(());
self.0.notify_one();
}
}
@@ -25,10 +25,10 @@ impl<Context: crate::Context> Server<Context> {
listener: impl Stream<Item = std::io::Result<T>> + 'a,
error_handler: impl Fn(std::io::Error) + Sync + 'a,
) -> (ShutdownHandle, impl Future<Output = ()> + 'a) {
let shutdown = Arc::new(OnceCell::new());
let shutdown = Arc::new(Notify::new());
(ShutdownHandle(shutdown.clone()), async move {
let mut runner = JobRunner::<std::io::Result<()>>::new();
let jobs = listener.map(|pipe| async {
let jobs = StreamUntil::new(listener, shutdown.notified()).map(|pipe| async {
let pipe = pipe?;
let (r, mut w) = tokio::io::split(pipe);
let stream = self.stream(

View File

@@ -1,6 +1,7 @@
use std::fmt::Display;
use futures::future::BoxFuture;
use futures::future::{BoxFuture, FusedFuture};
use futures::stream::FusedStream;
use futures::{Future, FutureExt, Stream, StreamExt};
use imbl_value::Value;
use serde::de::DeserializeOwned;
@@ -148,6 +149,9 @@ impl<'a, T> JobRunner<'a, T> {
self.running.push(job.boxed());
} else {
self.closed = true;
if self.running.is_empty() {
return None;
}
}
}
res = self.next() => {
@@ -172,6 +176,52 @@ impl<'a, T> Stream for JobRunner<'a, T> {
}
}
#[pin_project::pin_project]
pub struct StreamUntil<S, F> {
#[pin]
stream: S,
#[pin]
until: F,
done: bool,
}
impl<S, F> StreamUntil<S, F> {
pub fn new(stream: S, until: F) -> Self {
Self {
stream,
until,
done: false,
}
}
}
impl<S, F> Stream for StreamUntil<S, F>
where
S: Stream,
F: Future,
{
type Item = S::Item;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let this = self.project();
*this.done = *this.done || this.until.poll(cx).is_ready();
if *this.done {
std::task::Poll::Ready(None)
} else {
this.stream.poll_next(cx)
}
}
}
impl<S, F> FusedStream for StreamUntil<S, F>
where
S: FusedStream,
F: FusedFuture,
{
fn is_terminated(&self) -> bool {
self.done || self.stream.is_terminated() || self.until.is_terminated()
}
}
// #[derive(Debug)]
// pub enum Infallible {}
// impl<T> From<Infallible> for T {

View File

@@ -1,119 +1,83 @@
// use std::fmt::Display;
// use std::str::FromStr;
// use std::sync::Arc;
use std::fmt::Display;
use std::str::FromStr;
use std::sync::Arc;
// use futures::FutureExt;
// use hyper::Request;
// use rpc_toolkit::clap::Arg;
// use rpc_toolkit::hyper::http::Error as HttpError;
// use rpc_toolkit::hyper::{Body, Response};
// use rpc_toolkit::rpc_server_helpers::{
// DynMiddlewareStage2, DynMiddlewareStage3, DynMiddlewareStage4,
// };
// use rpc_toolkit::serde::{Deserialize, Serialize};
// use rpc_toolkit::url::Host;
// use rpc_toolkit::yajrc::RpcError;
// use rpc_toolkit::{command, rpc_server, run_cli, Context, Metadata};
use futures::FutureExt;
use hyper::Request;
use rpc_toolkit::clap::Arg;
use rpc_toolkit::hyper::http::Error as HttpError;
use rpc_toolkit::hyper::Response;
use rpc_toolkit::serde::{Deserialize, Serialize};
use rpc_toolkit::url::Host;
use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{command, Context};
// #[derive(Debug, Clone)]
// pub struct AppState(Arc<ConfigSeed>);
// impl From<AppState> for () {
// fn from(_: AppState) -> Self {
// ()
// }
// }
#[derive(Debug, Clone)]
pub struct AppState(Arc<ConfigSeed>);
impl From<AppState> for () {
fn from(_: AppState) -> Self {
()
}
}
// #[derive(Debug)]
// pub struct ConfigSeed {
// host: Host,
// port: u16,
// }
#[derive(Debug)]
pub struct ConfigSeed {
host: Host,
port: u16,
}
// impl Context for AppState {
// type Metadata = ();
// }
impl Context for AppState {}
// fn test_string() -> String {
// "test".to_owned()
// }
#[command(
about = "Does the thing",
subcommands("dothething2::<U>", self(dothething_impl(async)))
)]
async fn dothething<U>(
#[context] _ctx: AppState,
#[arg(short = 'a')] arg1: Option<String>,
#[arg(short = 'b', default)] val: String,
#[arg(short = 'c', help = "I am the flag `c`!", default)] arg3: bool,
#[arg(stdin)] structured: U,
) -> Result<(Option<String>, String, bool, U), RpcError>
where
U: Serialize + for<'a> Deserialize<'a> + FromStr + Clone + 'static,
U::Err: Display,
{
Ok((arg1, val, arg3, structured))
}
// #[command(
// about = "Does the thing",
// subcommands("dothething2::<U, E>", self(dothething_impl(async)))
// )]
// async fn dothething<
// U: Serialize + for<'a> Deserialize<'a> + FromStr<Err = E> + Clone + 'static,
// E: Display,
// >(
// #[context] _ctx: AppState,
// #[arg(short = 'a')] arg1: Option<String>,
// #[arg(short = 'b', default = "test_string")] val: String,
// #[arg(short = 'c', help = "I am the flag `c`!", default)] arg3: bool,
// #[arg(stdin)] structured: U,
// ) -> Result<(Option<String>, String, bool, U), RpcError> {
// Ok((arg1, val, arg3, structured))
// }
async fn dothething_impl<U: Serialize>(
ctx: AppState,
parent_data: (Option<String>, String, bool, U),
) -> Result<String, RpcError> {
Ok(format!(
"{:?}, {:?}, {}, {}, {}",
ctx,
parent_data.0,
parent_data.1,
parent_data.2,
serde_json::to_string_pretty(&parent_data.3)?
))
}
// async fn dothething_impl<U: Serialize>(
// ctx: AppState,
// parent_data: (Option<String>, String, bool, U),
// ) -> Result<String, RpcError> {
// Ok(format!(
// "{:?}, {:?}, {}, {}, {}",
// ctx,
// parent_data.0,
// parent_data.1,
// parent_data.2,
// serde_json::to_string_pretty(&parent_data.3)?
// ))
// }
// #[command(about = "Does the thing")]
// fn dothething2<U: Serialize + for<'a> Deserialize<'a> + FromStr<Err = E>, E: Display>(
// #[parent_data] parent_data: (Option<String>, String, bool, U),
// #[arg(stdin)] structured2: U,
// ) -> Result<String, RpcError> {
// Ok(format!(
// "{:?}, {}, {}, {}, {}",
// parent_data.0,
// parent_data.1,
// parent_data.2,
// serde_json::to_string_pretty(&parent_data.3)?,
// serde_json::to_string_pretty(&structured2)?,
// ))
// }
// async fn cors<M: Metadata + 'static>(
// req: &mut Request<Body>,
// _: M,
// ) -> Result<Result<DynMiddlewareStage2, Response<Body>>, HttpError> {
// if req.method() == hyper::Method::OPTIONS {
// Ok(Err(Response::builder()
// .header("Access-Control-Allow-Origin", "*")
// .body(Body::empty())?))
// } else {
// Ok(Ok(Box::new(|_, _| {
// async move {
// let res: DynMiddlewareStage3 = Box::new(|_, _| {
// async move {
// let res: DynMiddlewareStage4 = Box::new(|res| {
// async move {
// res.headers_mut()
// .insert("Access-Control-Allow-Origin", "*".parse()?);
// Ok::<_, HttpError>(())
// }
// .boxed()
// });
// Ok::<_, HttpError>(Ok(res))
// }
// .boxed()
// });
// Ok::<_, HttpError>(Ok(res))
// }
// .boxed()
// })))
// }
// }
#[command(about = "Does the thing")]
fn dothething2<U>(
#[parent_data] parent_data: (Option<String>, String, bool, U),
#[arg(stdin)] structured2: U,
) -> Result<String, RpcError>
where
U: Serialize + for<'a> Deserialize<'a> + FromStr + Clone + 'static,
U::Err: Display,
{
Ok(format!(
"{:?}, {}, {}, {}, {}",
parent_data.0,
parent_data.1,
parent_data.2,
serde_json::to_string_pretty(&parent_data.3)?,
serde_json::to_string_pretty(&structured2)?,
))
}
// #[tokio::test]
// async fn test_rpc() {

View File

@@ -1,13 +1,19 @@
use std::any::TypeId;
use std::ffi::OsString;
use std::fmt::Display;
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use clap::Parser;
use rpc_toolkit::{from_fn, from_fn_async, AnyContext, CliApp, Context, NoParams, ParentHandler};
use futures::future::ready;
use imbl_value::Value;
use rpc_toolkit::{
call_remote_socket, from_fn, from_fn_async, AnyContext, CallRemote, CliApp, Context, NoParams,
ParentHandler, Server,
};
use serde::{Deserialize, Serialize};
use tokio::runtime::{Handle, Runtime};
use tokio::sync::OnceCell;
use url::Url;
use tokio::sync::{Mutex, OnceCell};
use yajrc::RpcError;
#[derive(Parser, Deserialize)]
@@ -18,7 +24,9 @@ use yajrc::RpcError;
about = "This is a test cli application."
)]
struct CliConfig {
host: Option<String>,
#[arg(long = "host")]
host: Option<PathBuf>,
#[arg(short = 'c', long = "config")]
config: Option<PathBuf>,
}
impl CliConfig {
@@ -40,7 +48,7 @@ impl CliConfig {
}
struct CliContextSeed {
host: Url,
host: PathBuf,
rt: OnceCell<Runtime>,
}
#[derive(Clone)]
@@ -53,6 +61,17 @@ impl Context for CliContext {
self.0.rt.get().unwrap().handle().clone()
}
}
#[async_trait::async_trait]
impl CallRemote for CliContext {
async fn call_remote(&self, method: &str, params: Value) -> Result<Value, RpcError> {
call_remote_socket(
tokio::net::UnixStream::connect(&self.0.host).await.unwrap(),
method,
params,
)
.await
}
}
fn make_cli() -> CliApp<CliContext, CliConfig> {
CliApp::new(
@@ -61,8 +80,7 @@ fn make_cli() -> CliApp<CliContext, CliConfig> {
Ok(CliContext(Arc::new(CliContextSeed {
host: config
.host
.map(|h| h.parse().unwrap())
.unwrap_or_else(|| "http://localhost:8080/rpc".parse().unwrap()),
.unwrap_or_else(|| Path::new("./rpc.sock").to_owned()),
rt: OnceCell::new(),
})))
},
@@ -70,16 +88,30 @@ fn make_cli() -> CliApp<CliContext, CliConfig> {
)
}
struct ServerContextSeed {
state: Mutex<Value>,
}
#[derive(Clone)]
struct ServerContext(Arc<ServerContextSeed>);
impl Context for ServerContext {}
fn make_server() -> Server<ServerContext> {
let ctx = ServerContext(Arc::new(ServerContextSeed {
state: Mutex::new(Value::Null),
}));
Server::new(move || ready(Ok(ctx.clone())), make_api())
}
fn make_api() -> ParentHandler {
impl CliContext {
fn host(&self) -> &Url {
&self.0.host
}
}
async fn a_hello(_: CliContext) -> Result<String, RpcError> {
Ok::<_, RpcError>("Async Subcommand".to_string())
}
#[derive(Debug, Clone, Deserialize, Serialize, Parser)]
struct EchoParams {
next: String,
}
#[derive(Debug, Clone, Deserialize, Serialize, Parser)]
struct HelloParams {
whom: String,
}
@@ -88,29 +120,20 @@ fn make_api() -> ParentHandler {
donde: String,
}
ParentHandler::new()
.subcommand(
.subcommand_remote_cli::<CliContext, _, _>(
"echo",
ParentHandler::<NoParams>::new()
.subcommand_no_cli(
"echo_no_cli",
from_fn(|c: CliContext| {
Ok::<_, RpcError>(
format!("Subcommand No Cli: Host {host}", host = c.host()).to_string(),
)
}),
)
.subcommand_no_cli(
"echo_cli",
from_fn(|c: CliContext| {
Ok::<_, RpcError>(
format!("Subcommand Cli: Host {host}", host = c.host()).to_string(),
)
}),
),
from_fn_async(
|c: ServerContext, EchoParams { next }: EchoParams| async move {
Ok::<_, RpcError>(std::mem::replace(
&mut *c.0.state.lock().await,
Value::String(Arc::new(next)),
))
},
),
)
.subcommand(
"hello",
from_fn(|_: CliContext, HelloParams { whom }: HelloParams| {
from_fn(|_: AnyContext, HelloParams { whom }: HelloParams| {
Ok::<_, RpcError>(format!("Hello {whom}").to_string())
}),
)
@@ -123,7 +146,7 @@ fn make_api() -> ParentHandler {
Ok::<_, RpcError>(
format!(
"Subcommand No Cli: Host {host} Donde = {donde}",
host = c.host()
host = c.0.host.display()
)
.to_string(),
)
@@ -137,8 +160,8 @@ fn make_api() -> ParentHandler {
from_fn(|c: CliContext, _, InheritParams { donde }| {
Ok::<_, RpcError>(
format!(
"Subcommand No Cli: Host {host} Donde = {donde}",
host = c.host(),
"Root Command: Host {host} Donde = {donde}",
host = c.0.host.display(),
)
.to_string(),
)
@@ -167,3 +190,63 @@ pub fn internal_error(e: impl Display) -> RpcError {
..yajrc::INTERNAL_ERROR
}
}
#[test]
fn test_cli() {
make_cli()
.run(
["test-cli", "hello", "me"]
.into_iter()
.map(|s| OsString::from(s)),
)
.unwrap();
make_cli()
.run(
["test-cli", "fizz", "buzz"]
.into_iter()
.map(|s| OsString::from(s)),
)
.unwrap();
}
#[tokio::test]
async fn test_server() {
let path = Path::new(env!("CARGO_TARGET_TMPDIR")).join("rpc.sock");
tokio::fs::remove_file(&path).await.unwrap_or_default();
let server = make_server();
let (shutdown, fut) = server
.run_unix(path.clone(), |err| eprintln!("IO Error: {err}"))
.unwrap();
tokio::join!(
tokio::task::spawn_blocking(move || {
make_cli()
.run(
[
"test-cli",
&format!("--host={}", path.display()),
"echo",
"foo",
]
.into_iter()
.map(|s| OsString::from(s)),
)
.unwrap();
make_cli()
.run(
[
"test-cli",
&format!("--host={}", path.display()),
"echo",
"bar",
]
.into_iter()
.map(|s| OsString::from(s)),
)
.unwrap();
shutdown.shutdown()
}),
fut
)
.0
.unwrap();
}