From 3195ffab6845b52da73d4e4a8bb1a0ef7eef3bbf Mon Sep 17 00:00:00 2001 From: Aiden McClelland Date: Tue, 31 Aug 2021 19:16:40 -0600 Subject: [PATCH] overhaul context --- .../src/command/build.rs | 208 +++++++++++++----- .../src/command/mod.rs | 1 + .../src/command/parse.rs | 52 ++++- .../src/rpc_server/build.rs | 15 +- .../src/rpc_server/mod.rs | 1 + .../src/rpc_server/parse.rs | 5 + .../src/run_cli/build.rs | 8 + .../src/run_cli/mod.rs | 1 + .../src/run_cli/parse.rs | 9 + rpc-toolkit/src/context.rs | 6 + rpc-toolkit/tests/test.rs | 77 +++---- 11 files changed, 282 insertions(+), 101 deletions(-) diff --git a/rpc-toolkit-macro-internals/src/command/build.rs b/rpc-toolkit-macro-internals/src/command/build.rs index 205e332..46a87fe 100644 --- a/rpc-toolkit-macro-internals/src/command/build.rs +++ b/rpc-toolkit-macro-internals/src/command/build.rs @@ -5,11 +5,29 @@ use quote::*; use syn::fold::Fold; use syn::punctuated::Punctuated; use syn::spanned::Spanned; -use syn::token::{Comma, Where}; +use syn::token::{Add, Comma, Where}; use super::parse::*; use super::*; +fn ctx_trait(ctx_ty: Type, opt: &Options) -> TokenStream { + let mut bounds: Punctuated = Punctuated::new(); + bounds.push(macro_try!(parse2(quote! { Into<#ctx_ty> }))); + bounds.push(macro_try!(parse2(quote! { ::rpc_toolkit::Context }))); + if let Options::Parent(ParentOptions { subcommands, .. }) = opt { + bounds.push(macro_try!(parse2(quote! { Clone }))); + for subcmd in subcommands { + let mut path = subcmd.clone(); + std::mem::take(&mut path.segments.last_mut().unwrap().arguments); + bounds.push(macro_try!(parse2(quote! { #path::CommandContext }))); + } + } + quote! { + pub trait CommandContext: #bounds {} + impl CommandContext for T where T: #bounds {} + } +} + fn metadata(full_options: &Options) -> TokenStream { let options = match full_options { Options::Leaf(a) => a, @@ -395,8 +413,18 @@ fn rpc_handler( opt: &Options, params: &[ParamType], ) -> TokenStream { + let mut parent_data_ty = quote! { () }; + let mut generics = fn_generics.clone(); + generics.params.push(macro_try!(syn::parse2( + quote! { GenericContext: CommandContext } + ))); + if generics.lt_token.is_none() { + generics.lt_token = Some(Default::default()); + } + if generics.gt_token.is_none() { + generics.gt_token = Some(Default::default()); + } let mut param_def = Vec::new(); - let mut ctx_ty = quote! { () }; for param in params { match param { ParamType::Arg(arg) => { @@ -412,9 +440,7 @@ fn rpc_handler( #field_name: #ty, }) } - ParamType::Context(ctx) => { - ctx_ty = quote! { #ctx }; - } + ParamType::ParentData(ty) => parent_data_ty = quote! { #ty }, _ => (), } } @@ -447,7 +473,16 @@ fn rpc_handler( let field_name = Ident::new(&format!("arg_{}", name), name.span()); quote! { args.#field_name } } - ParamType::Context(_) => quote! { ctx }, + ParamType::Context(ty) => { + if matches!(opt, Options::Parent { .. }) { + quote! { >::into(ctx.clone()) } + } else { + quote! { >::into(ctx) } + } + } + ParamType::ParentData(_) => { + quote! { parent_data } + } ParamType::Request => quote! { request }, ParamType::Response => quote! { response }, ParamType::None => unreachable!(), @@ -456,8 +491,9 @@ fn rpc_handler( Options::Leaf(opt) if matches!(opt.exec_ctx, ExecutionContext::CliOnly(_)) => quote! { #param_struct_def - pub async fn rpc_handler#fn_generics( - _ctx: #ctx_ty, + pub async fn rpc_handler#generics( + _ctx: GenericContext, + _parent_data: #parent_data_ty, _request: &::rpc_toolkit::command_helpers::prelude::RequestParts, _response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts, method: &str, @@ -486,8 +522,9 @@ fn rpc_handler( quote! { #param_struct_def - pub async fn rpc_handler#fn_generics( - ctx: #ctx_ty, + pub async fn rpc_handler#generics( + ctx: GenericContext, + parent_data: #parent_data_ty, request: &::rpc_toolkit::command_helpers::prelude::RequestParts, response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts, method: &str, @@ -511,28 +548,39 @@ fn rpc_handler( }) => { let cmd_preprocess = if common.is_async { quote! { - let ctx = #fn_path(#(#param),*).await?; + let parent_data = #fn_path(#(#param),*).await?; } } else if common.blocking.is_some() { quote! { - let ctx = ::rpc_toolkit::command_helpers::prelude::spawn_blocking(move || #fn_path(#(#param),*)).await?; + let parent_data = ::rpc_toolkit::command_helpers::prelude::spawn_blocking(move || #fn_path(#(#param),*)).await?; } } else { quote! { - let ctx = #fn_path(#(#param),*)?; + let parent_data = #fn_path(#(#param),*)?; } }; let subcmd_impl = subcommands.iter().map(|subcommand| { let mut subcommand = subcommand.clone(); - let rpc_handler = PathSegment { + let mut rpc_handler = PathSegment { ident: Ident::new("rpc_handler", Span::call_site()), arguments: std::mem::replace( &mut subcommand.segments.last_mut().unwrap().arguments, PathArguments::None, ), }; + rpc_handler.arguments = match rpc_handler.arguments { + PathArguments::None => PathArguments::AngleBracketed( + syn::parse2(quote! { :: }) + .unwrap(), + ), + PathArguments::AngleBracketed(mut a) => { + a.args.push(syn::parse2(quote! { GenericContext }).unwrap()); + PathArguments::AngleBracketed(a) + } + _ => unreachable!(), + }; quote_spanned!{ subcommand.span() => - [#subcommand::NAME, rest] => #subcommand::#rpc_handler(ctx, request, response, rest, ::rpc_toolkit::command_helpers::prelude::from_value(args.rest)?).await + [#subcommand::NAME, rest] => #subcommand::#rpc_handler(ctx, parent_data, request, response, rest, ::rpc_toolkit::command_helpers::prelude::from_value(args.rest)?).await } }); let subcmd_impl = quote! { @@ -551,22 +599,26 @@ fn rpc_handler( let self_impl_fn = &self_impl.path; let self_impl = if self_impl.is_async { quote_spanned! { self_impl_fn.span() => - #self_impl_fn(ctx).await? + #self_impl_fn(Into::into(ctx), parent_data).await? } } else if self_impl.blocking { quote_spanned! { self_impl_fn.span() => - ::rpc_toolkit::command_helpers::prelude::spawn_blocking(move || #self_impl_fn(ctx)).await? + { + let ctx = Into::into(ctx); + ::rpc_toolkit::command_helpers::prelude::spawn_blocking(move || #self_impl_fn(ctx, parent_data)).await? + } } } else { quote_spanned! { self_impl_fn.span() => - #self_impl_fn(ctx)? + #self_impl_fn(Into::into(ctx), parent_data)? } }; quote! { #param_struct_def - pub async fn rpc_handler#fn_generics( - ctx: #ctx_ty, + pub async fn rpc_handler#generics( + ctx: GenericContext, + parent_data: #parent_data_ty, request: &::rpc_toolkit::command_helpers::prelude::RequestParts, response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts, method: &str, @@ -586,8 +638,9 @@ fn rpc_handler( quote! { #param_struct_def - pub async fn rpc_handler#fn_generics( - ctx: #ctx_ty, + pub async fn rpc_handler#generics( + ctx: GenericContext, + parent_data: #parent_data_ty, request: &::rpc_toolkit::command_helpers::prelude::RequestParts, response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts, method: &str, @@ -610,19 +663,14 @@ fn cli_handler( opt: &mut Options, params: &[ParamType], ) -> TokenStream { - let mut ctx_ty = quote! { () }; - for param in params { - match param { - ParamType::Context(ctx) => { - ctx_ty = quote! { #ctx }; - } - _ => (), - } - } + let mut parent_data_ty = quote! { () }; let mut generics = fn_generics.clone(); generics.params.push(macro_try!(syn::parse2( quote! { ParentParams: ::rpc_toolkit::command_helpers::prelude::Serialize } ))); + generics.params.push(macro_try!(syn::parse2( + quote! { GenericContext: CommandContext } + ))); if generics.lt_token.is_none() { generics.lt_token = Some(Default::default()); } @@ -632,13 +680,24 @@ fn cli_handler( let (_, fn_type_generics, _) = fn_generics.split_for_impl(); let fn_turbofish = fn_type_generics.as_turbofish(); let fn_path: Path = macro_try!(syn::parse2(quote! { super::#fn_name#fn_turbofish })); + let is_parent = matches!(opt, Options::Parent { .. }); let param = params.iter().map(|param| match param { ParamType::Arg(arg) => { let name = arg.name.clone().unwrap(); let field_name = Ident::new(&format!("arg_{}", name), name.span()); quote! { params.#field_name.clone() } } - ParamType::Context(_) => quote! { ctx }, + ParamType::Context(ty) => { + if is_parent { + quote! { >::into(ctx.clone()) } + } else { + quote! { >::into(ctx) } + } + } + ParamType::ParentData(ty) => { + parent_data_ty = quote! { #ty }; + quote! { parent_data } + } ParamType::Request => quote! { request }, ParamType::Response => quote! { response }, ParamType::None => unreachable!(), @@ -654,10 +713,10 @@ fn cli_handler( ParentParams: ::rpc_toolkit::command_helpers::prelude::Serialize }))); if param_generics.lt_token.is_none() { - generics.lt_token = Some(Default::default()); + param_generics.lt_token = Some(Default::default()); } if param_generics.gt_token.is_none() { - generics.gt_token = Some(Default::default()); + param_generics.gt_token = Some(Default::default()); } let (_, param_ty_generics, _) = param_generics.split_for_impl(); let mut arg_def = Vec::new(); @@ -777,7 +836,8 @@ fn cli_handler( match opt { Options::Leaf(opt) if matches!(opt.exec_ctx, ExecutionContext::RpcOnly(_)) => quote! { pub fn cli_handler#generics( - _ctx: #ctx_ty, + _ctx: (), + _parent_data: #parent_data_ty, _rt: Option<::rpc_toolkit::command_helpers::prelude::Runtime>, _matches: &::rpc_toolkit::command_helpers::prelude::ArgMatches<'_>, method: ::rpc_toolkit::command_helpers::prelude::Cow<'_, str>, @@ -802,7 +862,8 @@ fn cli_handler( }; quote! { pub fn cli_handler#generics( - ctx: #ctx_ty, + ctx: GenericContext, + parent_data: #parent_data_ty, mut rt: Option<::rpc_toolkit::command_helpers::prelude::Runtime>, matches: &::rpc_toolkit::command_helpers::prelude::ArgMatches<'_>, method: ::rpc_toolkit::command_helpers::prelude::Cow<'_, str>, @@ -816,6 +877,9 @@ fn cli_handler( let return_ty = if true { ::rpc_toolkit::command_helpers::prelude::PhantomData } else { + let ctx_new = unreachable!(); + ::rpc_toolkit::command_helpers::prelude::match_types(&ctx, &ctx_new); + let ctx = ctx_new; ::rpc_toolkit::command_helpers::prelude::make_phantom(#invocation) }; @@ -830,13 +894,25 @@ fn cli_handler( } = opt.exec_ctx { let fn_path = cli; + let cli_param = params.iter().filter_map(|param| match param { + ParamType::Arg(arg) => { + let name = arg.name.clone().unwrap(); + let field_name = Ident::new(&format!("arg_{}", name), name.span()); + Some(quote! { params.#field_name.clone() }) + } + ParamType::Context(_) => Some(quote! { Into::into(ctx) }), + ParamType::ParentData(_) => Some(quote! { parent_data }), + ParamType::Request => None, + ParamType::Response => None, + ParamType::None => unreachable!(), + }); let invocation = if is_async { quote! { - rt_ref.block_on(#fn_path(#(#param),*))? + rt_ref.block_on(#fn_path(#(#cli_param),*))? } } else { quote! { - #fn_path(#(#param),*)? + #fn_path(#(#cli_param),*)? } }; let display_res = if let Some(display_fn) = &opt.display { @@ -857,7 +933,8 @@ fn cli_handler( }; quote! { pub fn cli_handler#generics( - ctx: #ctx_ty, + ctx: GenericContext, + parent_data: #parent_data_ty, mut rt: Option<::rpc_toolkit::command_helpers::prelude::Runtime>, matches: &::rpc_toolkit::command_helpers::prelude::ArgMatches<'_>, _method: ::rpc_toolkit::command_helpers::prelude::Cow<'_, str>, @@ -898,7 +975,8 @@ fn cli_handler( }; quote! { pub fn cli_handler#generics( - ctx: #ctx_ty, + ctx: GenericContext, + parent_data: #parent_data_ty, mut rt: Option<::rpc_toolkit::command_helpers::prelude::Runtime>, matches: &::rpc_toolkit::command_helpers::prelude::ArgMatches<'_>, _method: ::rpc_toolkit::command_helpers::prelude::Cow<'_, str>, @@ -921,11 +999,11 @@ fn cli_handler( let cmd_preprocess = if common.is_async { quote! { #create_rt - let ctx = rt_ref.block_on(#fn_path(#(#param),*))?; + let parent_data = rt_ref.block_on(#fn_path(#(#param),*))?; } } else { quote! { - let ctx = #fn_path(#(#param),*)?; + let parent_data = #fn_path(#(#param),*)?; } }; let subcmd_impl = subcommands.iter().map(|subcommand| { @@ -939,11 +1017,13 @@ fn cli_handler( }; cli_handler.arguments = match cli_handler.arguments { PathArguments::None => PathArguments::AngleBracketed( - syn::parse2(quote! { :: }).unwrap(), + syn::parse2(quote! { :: }) + .unwrap(), ), PathArguments::AngleBracketed(mut a) => { a.args .push(syn::parse2(quote! { Params#param_ty_generics }).unwrap()); + a.args.push(syn::parse2(quote! { GenericContext }).unwrap()); PathArguments::AngleBracketed(a) } _ => unreachable!(), @@ -955,26 +1035,34 @@ fn cli_handler( } else { method + "." + #subcommand::NAME }; - #subcommand::#cli_handler(ctx, rt, sub_m, method, params) + #subcommand::#cli_handler(ctx, parent_data, rt, sub_m, method, params) }, } }); let self_impl = match (self_impl, &common.exec_ctx) { - (Some(self_impl), ExecutionContext::CliOnly(_)) => { - let self_impl_fn = &self_impl.path; + (Some(self_impl), ExecutionContext::CliOnly(_)) + | (Some(self_impl), ExecutionContext::Local(_)) + | (Some(self_impl), ExecutionContext::CustomCli { .. }) => { + let (self_impl_fn, is_async) = + if let ExecutionContext::CustomCli { cli, is_async, .. } = &common.exec_ctx + { + (cli, *is_async) + } else { + (&self_impl.path, self_impl.is_async) + }; let create_rt = if common.is_async { None } else { Some(create_rt) }; - let self_impl = if self_impl.is_async { + let self_impl = if is_async { quote_spanned! { self_impl_fn.span() => #create_rt - rt_ref.block_on(#self_impl_fn(ctx))? + rt_ref.block_on(#self_impl_fn(Into::into(ctx), parent_data))? } } else { quote_spanned! { self_impl_fn.span() => - #self_impl_fn(ctx)? + #self_impl_fn(Into::into(ctx), parent_data)? } }; quote! { @@ -985,11 +1073,11 @@ fn cli_handler( let self_impl_fn = &self_impl.path; let self_impl = if self_impl.is_async { quote! { - rt_ref.block_on(#self_impl_fn(ctx)) + rt_ref.block_on(#self_impl_fn(Into::into(ctx), parent_data)) } } else { quote! { - #self_impl_fn(ctx) + #self_impl_fn(Into::into(ctx), parent_data) } }; let create_rt = if common.is_async { @@ -1016,7 +1104,7 @@ fn cli_handler( } } } - _ => quote! { + (None, _) | (Some(_), ExecutionContext::RpcOnly(_)) => quote! { Err(::rpc_toolkit::command_helpers::prelude::RpcError { data: Some(method.into()), ..::rpc_toolkit::command_helpers::prelude::yajrc::METHOD_NOT_FOUND_ERROR @@ -1025,7 +1113,8 @@ fn cli_handler( }; quote! { pub fn cli_handler#generics( - ctx: #ctx_ty, + ctx: GenericContext, + parent_data: #parent_data_ty, mut rt: Option<::rpc_toolkit::command_helpers::prelude::Runtime>, matches: &::rpc_toolkit::command_helpers::prelude::ArgMatches<'_>, method: ::rpc_toolkit::command_helpers::prelude::Cow<'_, str>, @@ -1071,6 +1160,17 @@ pub fn build(args: AttributeArgs, mut item: ItemFn) -> TokenStream { .map(|a| a.span()) .unwrap_or_else(Span::call_site), ); + let ctx_ty = params + .iter() + .find_map(|a| { + if let ParamType::Context(ty) = a { + Some(ty.clone()) + } else { + None + } + }) + .unwrap_or(macro_try!(syn::parse2(quote! { () }))); + let ctx_trait = ctx_trait(ctx_ty, &opt); let metadata = metadata(&mut opt); let build_app = build_app(command_name_str.clone(), &mut opt, &mut params); let rpc_handler = rpc_handler(fn_name, fn_generics, &opt, ¶ms); @@ -1084,6 +1184,8 @@ pub fn build(args: AttributeArgs, mut item: ItemFn) -> TokenStream { pub const NAME: &'static str = #command_name_str; pub const ASYNC: bool = #is_async; + #ctx_trait + #metadata #build_app diff --git a/rpc-toolkit-macro-internals/src/command/mod.rs b/rpc-toolkit-macro-internals/src/command/mod.rs index 9b221a6..0b77805 100644 --- a/rpc-toolkit-macro-internals/src/command/mod.rs +++ b/rpc-toolkit-macro-internals/src/command/mod.rs @@ -96,6 +96,7 @@ pub enum ParamType { None, Arg(ArgOptions), Context(Type), + ParentData(Type), Request, Response, } diff --git a/rpc-toolkit-macro-internals/src/command/parse.rs b/rpc-toolkit-macro-internals/src/command/parse.rs index 0108a3d..fdd5796 100644 --- a/rpc-toolkit-macro-internals/src/command/parse.rs +++ b/rpc-toolkit-macro-internals/src/command/parse.rs @@ -774,6 +774,11 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result> { attr.span(), "`arg` and `context` are mutually exclusive", )); + } else if matches!(ty, ParamType::ParentData(_)) { + return Err(Error::new( + attr.span(), + "`arg` and `parent_data` are mutually exclusive", + )); } else if matches!(ty, ParamType::Request) { return Err(Error::new( attr.span(), @@ -799,6 +804,11 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result> { attr.span(), "`arg` and `context` are mutually exclusive", )); + } else if matches!(ty, ParamType::ParentData(_)) { + return Err(Error::new( + attr.span(), + "`context` and `parent_data` are mutually exclusive", + )); } else if matches!(ty, ParamType::Request) { return Err(Error::new( attr.span(), @@ -810,6 +820,36 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result> { "`context` and `response` are mutually exclusive", )); } + } else if param.attrs[i].path.is_ident("parent_data") { + let attr = param.attrs.remove(i); + if matches!(ty, ParamType::None) { + ty = ParamType::ParentData(*param.ty.clone()); + } else if matches!(ty, ParamType::ParentData(_)) { + return Err(Error::new( + attr.span(), + "`parent_data` attribute may only be specified once", + )); + } else if matches!(ty, ParamType::Arg(_)) { + return Err(Error::new( + attr.span(), + "`arg` and `parent_data` are mutually exclusive", + )); + } else if matches!(ty, ParamType::Context(_)) { + return Err(Error::new( + attr.span(), + "`context` and `parent_data` are mutually exclusive", + )); + } else if matches!(ty, ParamType::Request) { + return Err(Error::new( + attr.span(), + "`parent_data` and `request` are mutually exclusive", + )); + } else if matches!(ty, ParamType::Response) { + return Err(Error::new( + attr.span(), + "`parent_data` and `response` are mutually exclusive", + )); + } } else if param.attrs[i].path.is_ident("request") { let attr = param.attrs.remove(i); if matches!(ty, ParamType::None) { @@ -829,6 +869,11 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result> { attr.span(), "`context` and `request` are mutually exclusive", )); + } else if matches!(ty, ParamType::ParentData(_)) { + return Err(Error::new( + attr.span(), + "`parent_data` and `request` are mutually exclusive", + )); } else if matches!(ty, ParamType::Response) { return Err(Error::new( attr.span(), @@ -854,6 +899,11 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result> { attr.span(), "`context` and `response` are mutually exclusive", )); + } else if matches!(ty, ParamType::Context(_)) { + return Err(Error::new( + attr.span(), + "`parent_data` and `response` are mutually exclusive", + )); } else if matches!(ty, ParamType::Request) { return Err(Error::new( attr.span(), @@ -867,7 +917,7 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result> { if matches!(ty, ParamType::None) { return Err(Error::new( param.span(), - "must specify either `arg` or `context` attributes", + "must specify either `arg`, `context`, `parent_data`, `request`, or `response` attributes", )); } params.push(ty) diff --git a/rpc-toolkit-macro-internals/src/rpc_server/build.rs b/rpc-toolkit-macro-internals/src/rpc_server/build.rs index 7d9af77..f2a4ae0 100644 --- a/rpc-toolkit-macro-internals/src/rpc_server/build.rs +++ b/rpc-toolkit-macro-internals/src/rpc_server/build.rs @@ -6,16 +6,24 @@ use super::*; pub fn build(args: RpcServerArgs) -> TokenStream { let mut command = args.command; - let arguments = std::mem::replace( + let mut arguments = std::mem::replace( &mut command.segments.last_mut().unwrap().arguments, PathArguments::None, ); let command_module = command.clone(); + if let PathArguments::AngleBracketed(a) = &mut arguments { + a.args.push(syn::parse2(quote! { _ }).unwrap()); + } command.segments.push(PathSegment { ident: Ident::new("rpc_handler", command.span()), arguments, }); let ctx = args.ctx; + let parent_data = if let Some(data) = args.parent_data { + quote! { #data } + } else { + quote! { () } + }; let status_fn = args.status_fn.unwrap_or_else(|| { syn::parse2(quote! { |_| ::rpc_toolkit::hyper::StatusCode::OK }).unwrap() }); @@ -47,6 +55,7 @@ pub fn build(args: RpcServerArgs) -> TokenStream { let res = quote! { { let ctx = #ctx; + let parent_data = #parent_data; let status_fn = #status_fn; let builder = ::rpc_toolkit::rpc_server_helpers::make_builder(&ctx); #( @@ -54,12 +63,14 @@ pub fn build(args: RpcServerArgs) -> TokenStream { )* let make_svc = ::rpc_toolkit::hyper::service::make_service_fn(move |_| { let ctx = ctx.clone(); + let parent_data = parent_data.clone(); #( let #middleware_name_clone3 = #middleware_name_clone2.clone(); )* async move { Ok::<_, ::rpc_toolkit::hyper::Error>(::rpc_toolkit::hyper::service::service_fn(move |mut req| { let ctx = ctx.clone(); + let parent_data = parent_data.clone(); let metadata = #command_module::Metadata::default(); #( let #middleware_name_clone5 = #middleware_name_clone4.clone(); @@ -83,7 +94,7 @@ pub fn build(args: RpcServerArgs) -> TokenStream { }; )* let mut rpc_res = match ::rpc_toolkit::serde_json::from_value(::rpc_toolkit::serde_json::Value::Object(rpc_req.params)) { - Ok(params) => #command(ctx, &req_parts, &mut res_parts, ::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method), params).await, + Ok(params) => #command(ctx, parent_data, &req_parts, &mut res_parts, ::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method), params).await, Err(e) => Err(e.into()) }; #( diff --git a/rpc-toolkit-macro-internals/src/rpc_server/mod.rs b/rpc-toolkit-macro-internals/src/rpc_server/mod.rs index 4e8f703..7016ce6 100644 --- a/rpc-toolkit-macro-internals/src/rpc_server/mod.rs +++ b/rpc-toolkit-macro-internals/src/rpc_server/mod.rs @@ -3,6 +3,7 @@ use syn::*; pub struct RpcServerArgs { command: Path, ctx: Expr, + parent_data: Option, status_fn: Option, middleware: punctuated::Punctuated, } diff --git a/rpc-toolkit-macro-internals/src/rpc_server/parse.rs b/rpc-toolkit-macro-internals/src/rpc_server/parse.rs index 3d03afd..4530f6e 100644 --- a/rpc-toolkit-macro-internals/src/rpc_server/parse.rs +++ b/rpc-toolkit-macro-internals/src/rpc_server/parse.rs @@ -9,6 +9,7 @@ impl Parse for RpcServerArgs { braced!(args in input); let mut command = None; let mut ctx = None; + let mut parent_data = None; let mut status_fn = None; let mut middleware = Punctuated::new(); while !args.is_empty() { @@ -21,6 +22,9 @@ impl Parse for RpcServerArgs { "context" => { ctx = Some(args.parse()?); } + "parent_data" => { + parent_data = Some(args.parse()?); + } "status" => { status_fn = Some(args.parse()?); } @@ -38,6 +42,7 @@ impl Parse for RpcServerArgs { Ok(RpcServerArgs { command: command.expect("`command` is required"), ctx: ctx.expect("`context` is required"), + parent_data, status_fn, middleware, }) diff --git a/rpc-toolkit-macro-internals/src/run_cli/build.rs b/rpc-toolkit-macro-internals/src/run_cli/build.rs index 3278507..bbf56af 100644 --- a/rpc-toolkit-macro-internals/src/run_cli/build.rs +++ b/rpc-toolkit-macro-internals/src/run_cli/build.rs @@ -13,6 +13,7 @@ pub fn build(args: RunCliArgs) -> TokenStream { let command = command_handler.clone(); if let PathArguments::AngleBracketed(a) = &mut arguments { a.args.push(syn::parse2(quote! { () }).unwrap()); + a.args.push(syn::parse2(quote! { _ }).unwrap()); } command_handler.segments.push(PathSegment { ident: Ident::new("cli_handler", command.span()), @@ -42,6 +43,11 @@ pub fn build(args: RunCliArgs) -> TokenStream { } else { quote! { &rpc_toolkit_matches } }; + let parent_data = if let Some(data) = args.parent_data { + quote! { #data } + } else { + quote! { () } + }; let exit_fn = args.exit_fn.unwrap_or_else(|| { syn::parse2(quote! { |err: ::rpc_toolkit::yajrc::RpcError| { eprintln!("{}", err.message); @@ -56,8 +62,10 @@ pub fn build(args: RunCliArgs) -> TokenStream { { let rpc_toolkit_matches = #app.get_matches(); let rpc_toolkit_ctx = #make_ctx; + let rpc_toolkit_parent_data = #parent_data; if let Err(err) = #command_handler( rpc_toolkit_ctx, + rpc_toolkit_parent_data, None, &rpc_toolkit_matches, "".into(), diff --git a/rpc-toolkit-macro-internals/src/run_cli/mod.rs b/rpc-toolkit-macro-internals/src/run_cli/mod.rs index 91ae048..780ae90 100644 --- a/rpc-toolkit-macro-internals/src/run_cli/mod.rs +++ b/rpc-toolkit-macro-internals/src/run_cli/mod.rs @@ -14,6 +14,7 @@ pub struct RunCliArgs { command: Path, mut_app: Option, make_ctx: Option, + parent_data: Option, exit_fn: Option, } diff --git a/rpc-toolkit-macro-internals/src/run_cli/parse.rs b/rpc-toolkit-macro-internals/src/run_cli/parse.rs index 82bd770..0f44c05 100644 --- a/rpc-toolkit-macro-internals/src/run_cli/parse.rs +++ b/rpc-toolkit-macro-internals/src/run_cli/parse.rs @@ -45,6 +45,14 @@ impl Parse for RunCliArgs { if !input.is_empty() { let _: token::Comma = input.parse()?; } + if !input.is_empty() { + let _: token::Comma = input.parse()?; + } + let parent_data = if !input.is_empty() { + Some(input.parse()?) + } else { + None + }; let exit_fn = if !input.is_empty() { Some(input.parse()?) } else { @@ -54,6 +62,7 @@ impl Parse for RunCliArgs { command, mut_app, make_ctx, + parent_data, exit_fn, }) } diff --git a/rpc-toolkit/src/context.rs b/rpc-toolkit/src/context.rs index e1d5ecf..1cae28e 100644 --- a/rpc-toolkit/src/context.rs +++ b/rpc-toolkit/src/context.rs @@ -33,3 +33,9 @@ pub trait Context { } impl Context for () {} + +impl<'a, T: Context + 'a> From for Box { + fn from(ctx: T) -> Self { + Box::new(ctx) + } +} diff --git a/rpc-toolkit/tests/test.rs b/rpc-toolkit/tests/test.rs index 8be9688..4ac992b 100644 --- a/rpc-toolkit/tests/test.rs +++ b/rpc-toolkit/tests/test.rs @@ -8,8 +8,7 @@ use rpc_toolkit::clap::Arg; use rpc_toolkit::hyper::http::Error as HttpError; use rpc_toolkit::hyper::{Body, Response}; use rpc_toolkit::rpc_server_helpers::{ - constrain_middleware, DynMiddleware, DynMiddlewareStage2, DynMiddlewareStage3, - DynMiddlewareStage4, + DynMiddlewareStage2, DynMiddlewareStage3, DynMiddlewareStage4, }; use rpc_toolkit::serde::{Deserialize, Serialize}; use rpc_toolkit::url::Host; @@ -17,34 +16,29 @@ use rpc_toolkit::yajrc::RpcError; use rpc_toolkit::{command, rpc_server, run_cli, Context, Metadata}; #[derive(Debug, Clone)] -pub struct AppState { - seed: T, - data: U, -} -impl AppState { - pub fn map V, V>(self, f: F) -> AppState { - AppState { - seed: self.seed, - data: f(self.data), - } +pub struct AppState(Arc); +impl From for () { + fn from(_: AppState) -> Self { + () } } +#[derive(Debug)] pub struct ConfigSeed { host: Host, port: u16, } -impl Context for AppState, T> { +impl Context for AppState { fn host(&self) -> Host<&str> { - match &self.seed.host { + match &self.0.host { Host::Domain(s) => Host::Domain(s.as_str()), Host::Ipv4(i) => Host::Ipv4(*i), Host::Ipv6(i) => Host::Ipv6(*i), } } fn port(&self) -> u16 { - self.seed.port + self.0.port } } @@ -53,41 +47,43 @@ impl Context for AppState, T> { subcommands("dothething2::", self(dothething_impl(async))) )] async fn dothething< - U: Serialize + for<'a> Deserialize<'a> + FromStr + Clone, + U: Serialize + for<'a> Deserialize<'a> + FromStr + Clone + 'static, E: Display, >( - #[context] ctx: AppState, ()>, + #[context] _ctx: AppState, #[arg(short = "a")] arg1: Option, #[arg(short = "b")] val: String, #[arg(short = "c", help = "I am the flag `c`!")] arg3: bool, #[arg(stdin)] structured: U, -) -> Result, (Option, String, bool, U)>, RpcError> { - Ok(ctx.map(|_| (arg1, val, arg3, structured))) +) -> Result<(Option, String, bool, U), RpcError> { + Ok((arg1, val, arg3, structured)) } async fn dothething_impl( - ctx: AppState, (Option, String, bool, U)>, + ctx: AppState, + parent_data: (Option, String, bool, U), ) -> Result { Ok(format!( - "{:?}, {}, {}, {}", - ctx.data.0, - ctx.data.1, - ctx.data.2, - serde_json::to_string_pretty(&ctx.data.3)? + "{:?}, {:?}, {}, {}, {}", + ctx, + parent_data.0, + parent_data.1, + parent_data.2, + serde_json::to_string_pretty(&parent_data.3)? )) } #[command(about = "Does the thing")] fn dothething2 Deserialize<'a> + FromStr, E: Display>( - #[context] ctx: AppState, (Option, String, bool, U)>, + #[parent_data] parent_data: (Option, String, bool, U), #[arg(stdin)] structured2: U, ) -> Result { Ok(format!( "{:?}, {}, {}, {}, {}", - ctx.data.0, - ctx.data.1, - ctx.data.2, - serde_json::to_string_pretty(&ctx.data.3)?, + parent_data.0, + parent_data.1, + parent_data.2, + serde_json::to_string_pretty(&parent_data.3)?, serde_json::to_string_pretty(&structured2)?, )) } @@ -134,7 +130,7 @@ async fn test_rpc() { }); let server = rpc_server!({ command: dothething::, - context: AppState { seed, data: () }, + context: AppState(seed), middleware: [ cors, ], @@ -192,14 +188,8 @@ fn cli_test() { host: Host::parse("localhost").unwrap(), port: 8000, }); - dothething::cli_handler::( - AppState { seed, data: () }, - None, - &matches, - "".into(), - (), - ) - .unwrap(); + dothething::cli_handler::(AppState(seed), (), None, &matches, "".into(), ()) + .unwrap(); } #[test] @@ -210,14 +200,11 @@ fn cli_example() { app => app .arg(Arg::with_name("host").long("host").short("h").takes_value(true)) .arg(Arg::with_name("port").long("port").short("p").takes_value(true)), - matches => AppState { seed: Arc::new(ConfigSeed { + matches => AppState(Arc::new(ConfigSeed { host: Host::parse(matches.value_of("host").unwrap_or("localhost")).unwrap(), port: matches.value_of("port").unwrap_or("8000").parse().unwrap(), - }), data: () } + })) ) } -fn type_check() { - let middleware: DynMiddleware = todo!(); - constrain_middleware(&middleware); -} +////////////////////////////////////////////////