overhaul context

This commit is contained in:
Aiden McClelland
2021-08-31 19:16:40 -06:00
parent ee9b302ecd
commit 3195ffab68
11 changed files with 282 additions and 101 deletions

View File

@@ -5,11 +5,29 @@ use quote::*;
use syn::fold::Fold; use syn::fold::Fold;
use syn::punctuated::Punctuated; use syn::punctuated::Punctuated;
use syn::spanned::Spanned; use syn::spanned::Spanned;
use syn::token::{Comma, Where}; use syn::token::{Add, Comma, Where};
use super::parse::*; use super::parse::*;
use super::*; use super::*;
fn ctx_trait(ctx_ty: Type, opt: &Options) -> TokenStream {
let mut bounds: Punctuated<TypeParamBound, Add> = Punctuated::new();
bounds.push(macro_try!(parse2(quote! { Into<#ctx_ty> })));
bounds.push(macro_try!(parse2(quote! { ::rpc_toolkit::Context })));
if let Options::Parent(ParentOptions { subcommands, .. }) = opt {
bounds.push(macro_try!(parse2(quote! { Clone })));
for subcmd in subcommands {
let mut path = subcmd.clone();
std::mem::take(&mut path.segments.last_mut().unwrap().arguments);
bounds.push(macro_try!(parse2(quote! { #path::CommandContext })));
}
}
quote! {
pub trait CommandContext: #bounds {}
impl<T> CommandContext for T where T: #bounds {}
}
}
fn metadata(full_options: &Options) -> TokenStream { fn metadata(full_options: &Options) -> TokenStream {
let options = match full_options { let options = match full_options {
Options::Leaf(a) => a, Options::Leaf(a) => a,
@@ -395,8 +413,18 @@ fn rpc_handler(
opt: &Options, opt: &Options,
params: &[ParamType], params: &[ParamType],
) -> TokenStream { ) -> TokenStream {
let mut parent_data_ty = quote! { () };
let mut generics = fn_generics.clone();
generics.params.push(macro_try!(syn::parse2(
quote! { GenericContext: CommandContext }
)));
if generics.lt_token.is_none() {
generics.lt_token = Some(Default::default());
}
if generics.gt_token.is_none() {
generics.gt_token = Some(Default::default());
}
let mut param_def = Vec::new(); let mut param_def = Vec::new();
let mut ctx_ty = quote! { () };
for param in params { for param in params {
match param { match param {
ParamType::Arg(arg) => { ParamType::Arg(arg) => {
@@ -412,9 +440,7 @@ fn rpc_handler(
#field_name: #ty, #field_name: #ty,
}) })
} }
ParamType::Context(ctx) => { ParamType::ParentData(ty) => parent_data_ty = quote! { #ty },
ctx_ty = quote! { #ctx };
}
_ => (), _ => (),
} }
} }
@@ -447,7 +473,16 @@ fn rpc_handler(
let field_name = Ident::new(&format!("arg_{}", name), name.span()); let field_name = Ident::new(&format!("arg_{}", name), name.span());
quote! { args.#field_name } quote! { args.#field_name }
} }
ParamType::Context(_) => quote! { ctx }, ParamType::Context(ty) => {
if matches!(opt, Options::Parent { .. }) {
quote! { <GenericContext as Into<#ty>>::into(ctx.clone()) }
} else {
quote! { <GenericContext as Into<#ty>>::into(ctx) }
}
}
ParamType::ParentData(_) => {
quote! { parent_data }
}
ParamType::Request => quote! { request }, ParamType::Request => quote! { request },
ParamType::Response => quote! { response }, ParamType::Response => quote! { response },
ParamType::None => unreachable!(), ParamType::None => unreachable!(),
@@ -456,8 +491,9 @@ fn rpc_handler(
Options::Leaf(opt) if matches!(opt.exec_ctx, ExecutionContext::CliOnly(_)) => quote! { Options::Leaf(opt) if matches!(opt.exec_ctx, ExecutionContext::CliOnly(_)) => quote! {
#param_struct_def #param_struct_def
pub async fn rpc_handler#fn_generics( pub async fn rpc_handler#generics(
_ctx: #ctx_ty, _ctx: GenericContext,
_parent_data: #parent_data_ty,
_request: &::rpc_toolkit::command_helpers::prelude::RequestParts, _request: &::rpc_toolkit::command_helpers::prelude::RequestParts,
_response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts, _response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts,
method: &str, method: &str,
@@ -486,8 +522,9 @@ fn rpc_handler(
quote! { quote! {
#param_struct_def #param_struct_def
pub async fn rpc_handler#fn_generics( pub async fn rpc_handler#generics(
ctx: #ctx_ty, ctx: GenericContext,
parent_data: #parent_data_ty,
request: &::rpc_toolkit::command_helpers::prelude::RequestParts, request: &::rpc_toolkit::command_helpers::prelude::RequestParts,
response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts, response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts,
method: &str, method: &str,
@@ -511,28 +548,39 @@ fn rpc_handler(
}) => { }) => {
let cmd_preprocess = if common.is_async { let cmd_preprocess = if common.is_async {
quote! { quote! {
let ctx = #fn_path(#(#param),*).await?; let parent_data = #fn_path(#(#param),*).await?;
} }
} else if common.blocking.is_some() { } else if common.blocking.is_some() {
quote! { quote! {
let ctx = ::rpc_toolkit::command_helpers::prelude::spawn_blocking(move || #fn_path(#(#param),*)).await?; let parent_data = ::rpc_toolkit::command_helpers::prelude::spawn_blocking(move || #fn_path(#(#param),*)).await?;
} }
} else { } else {
quote! { quote! {
let ctx = #fn_path(#(#param),*)?; let parent_data = #fn_path(#(#param),*)?;
} }
}; };
let subcmd_impl = subcommands.iter().map(|subcommand| { let subcmd_impl = subcommands.iter().map(|subcommand| {
let mut subcommand = subcommand.clone(); let mut subcommand = subcommand.clone();
let rpc_handler = PathSegment { let mut rpc_handler = PathSegment {
ident: Ident::new("rpc_handler", Span::call_site()), ident: Ident::new("rpc_handler", Span::call_site()),
arguments: std::mem::replace( arguments: std::mem::replace(
&mut subcommand.segments.last_mut().unwrap().arguments, &mut subcommand.segments.last_mut().unwrap().arguments,
PathArguments::None, PathArguments::None,
), ),
}; };
rpc_handler.arguments = match rpc_handler.arguments {
PathArguments::None => PathArguments::AngleBracketed(
syn::parse2(quote! { ::<GenericContext> })
.unwrap(),
),
PathArguments::AngleBracketed(mut a) => {
a.args.push(syn::parse2(quote! { GenericContext }).unwrap());
PathArguments::AngleBracketed(a)
}
_ => unreachable!(),
};
quote_spanned!{ subcommand.span() => quote_spanned!{ subcommand.span() =>
[#subcommand::NAME, rest] => #subcommand::#rpc_handler(ctx, request, response, rest, ::rpc_toolkit::command_helpers::prelude::from_value(args.rest)?).await [#subcommand::NAME, rest] => #subcommand::#rpc_handler(ctx, parent_data, request, response, rest, ::rpc_toolkit::command_helpers::prelude::from_value(args.rest)?).await
} }
}); });
let subcmd_impl = quote! { let subcmd_impl = quote! {
@@ -551,22 +599,26 @@ fn rpc_handler(
let self_impl_fn = &self_impl.path; let self_impl_fn = &self_impl.path;
let self_impl = if self_impl.is_async { let self_impl = if self_impl.is_async {
quote_spanned! { self_impl_fn.span() => quote_spanned! { self_impl_fn.span() =>
#self_impl_fn(ctx).await? #self_impl_fn(Into::into(ctx), parent_data).await?
} }
} else if self_impl.blocking { } else if self_impl.blocking {
quote_spanned! { self_impl_fn.span() => quote_spanned! { self_impl_fn.span() =>
::rpc_toolkit::command_helpers::prelude::spawn_blocking(move || #self_impl_fn(ctx)).await? {
let ctx = Into::into(ctx);
::rpc_toolkit::command_helpers::prelude::spawn_blocking(move || #self_impl_fn(ctx, parent_data)).await?
}
} }
} else { } else {
quote_spanned! { self_impl_fn.span() => quote_spanned! { self_impl_fn.span() =>
#self_impl_fn(ctx)? #self_impl_fn(Into::into(ctx), parent_data)?
} }
}; };
quote! { quote! {
#param_struct_def #param_struct_def
pub async fn rpc_handler#fn_generics( pub async fn rpc_handler#generics(
ctx: #ctx_ty, ctx: GenericContext,
parent_data: #parent_data_ty,
request: &::rpc_toolkit::command_helpers::prelude::RequestParts, request: &::rpc_toolkit::command_helpers::prelude::RequestParts,
response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts, response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts,
method: &str, method: &str,
@@ -586,8 +638,9 @@ fn rpc_handler(
quote! { quote! {
#param_struct_def #param_struct_def
pub async fn rpc_handler#fn_generics( pub async fn rpc_handler#generics(
ctx: #ctx_ty, ctx: GenericContext,
parent_data: #parent_data_ty,
request: &::rpc_toolkit::command_helpers::prelude::RequestParts, request: &::rpc_toolkit::command_helpers::prelude::RequestParts,
response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts, response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts,
method: &str, method: &str,
@@ -610,19 +663,14 @@ fn cli_handler(
opt: &mut Options, opt: &mut Options,
params: &[ParamType], params: &[ParamType],
) -> TokenStream { ) -> TokenStream {
let mut ctx_ty = quote! { () }; let mut parent_data_ty = quote! { () };
for param in params {
match param {
ParamType::Context(ctx) => {
ctx_ty = quote! { #ctx };
}
_ => (),
}
}
let mut generics = fn_generics.clone(); let mut generics = fn_generics.clone();
generics.params.push(macro_try!(syn::parse2( generics.params.push(macro_try!(syn::parse2(
quote! { ParentParams: ::rpc_toolkit::command_helpers::prelude::Serialize } quote! { ParentParams: ::rpc_toolkit::command_helpers::prelude::Serialize }
))); )));
generics.params.push(macro_try!(syn::parse2(
quote! { GenericContext: CommandContext }
)));
if generics.lt_token.is_none() { if generics.lt_token.is_none() {
generics.lt_token = Some(Default::default()); generics.lt_token = Some(Default::default());
} }
@@ -632,13 +680,24 @@ fn cli_handler(
let (_, fn_type_generics, _) = fn_generics.split_for_impl(); let (_, fn_type_generics, _) = fn_generics.split_for_impl();
let fn_turbofish = fn_type_generics.as_turbofish(); let fn_turbofish = fn_type_generics.as_turbofish();
let fn_path: Path = macro_try!(syn::parse2(quote! { super::#fn_name#fn_turbofish })); let fn_path: Path = macro_try!(syn::parse2(quote! { super::#fn_name#fn_turbofish }));
let is_parent = matches!(opt, Options::Parent { .. });
let param = params.iter().map(|param| match param { let param = params.iter().map(|param| match param {
ParamType::Arg(arg) => { ParamType::Arg(arg) => {
let name = arg.name.clone().unwrap(); let name = arg.name.clone().unwrap();
let field_name = Ident::new(&format!("arg_{}", name), name.span()); let field_name = Ident::new(&format!("arg_{}", name), name.span());
quote! { params.#field_name.clone() } quote! { params.#field_name.clone() }
} }
ParamType::Context(_) => quote! { ctx }, ParamType::Context(ty) => {
if is_parent {
quote! { <GenericContext as Into<#ty>>::into(ctx.clone()) }
} else {
quote! { <GenericContext as Into<#ty>>::into(ctx) }
}
}
ParamType::ParentData(ty) => {
parent_data_ty = quote! { #ty };
quote! { parent_data }
}
ParamType::Request => quote! { request }, ParamType::Request => quote! { request },
ParamType::Response => quote! { response }, ParamType::Response => quote! { response },
ParamType::None => unreachable!(), ParamType::None => unreachable!(),
@@ -654,10 +713,10 @@ fn cli_handler(
ParentParams: ::rpc_toolkit::command_helpers::prelude::Serialize ParentParams: ::rpc_toolkit::command_helpers::prelude::Serialize
}))); })));
if param_generics.lt_token.is_none() { if param_generics.lt_token.is_none() {
generics.lt_token = Some(Default::default()); param_generics.lt_token = Some(Default::default());
} }
if param_generics.gt_token.is_none() { if param_generics.gt_token.is_none() {
generics.gt_token = Some(Default::default()); param_generics.gt_token = Some(Default::default());
} }
let (_, param_ty_generics, _) = param_generics.split_for_impl(); let (_, param_ty_generics, _) = param_generics.split_for_impl();
let mut arg_def = Vec::new(); let mut arg_def = Vec::new();
@@ -777,7 +836,8 @@ fn cli_handler(
match opt { match opt {
Options::Leaf(opt) if matches!(opt.exec_ctx, ExecutionContext::RpcOnly(_)) => quote! { Options::Leaf(opt) if matches!(opt.exec_ctx, ExecutionContext::RpcOnly(_)) => quote! {
pub fn cli_handler#generics( pub fn cli_handler#generics(
_ctx: #ctx_ty, _ctx: (),
_parent_data: #parent_data_ty,
_rt: Option<::rpc_toolkit::command_helpers::prelude::Runtime>, _rt: Option<::rpc_toolkit::command_helpers::prelude::Runtime>,
_matches: &::rpc_toolkit::command_helpers::prelude::ArgMatches<'_>, _matches: &::rpc_toolkit::command_helpers::prelude::ArgMatches<'_>,
method: ::rpc_toolkit::command_helpers::prelude::Cow<'_, str>, method: ::rpc_toolkit::command_helpers::prelude::Cow<'_, str>,
@@ -802,7 +862,8 @@ fn cli_handler(
}; };
quote! { quote! {
pub fn cli_handler#generics( pub fn cli_handler#generics(
ctx: #ctx_ty, ctx: GenericContext,
parent_data: #parent_data_ty,
mut rt: Option<::rpc_toolkit::command_helpers::prelude::Runtime>, mut rt: Option<::rpc_toolkit::command_helpers::prelude::Runtime>,
matches: &::rpc_toolkit::command_helpers::prelude::ArgMatches<'_>, matches: &::rpc_toolkit::command_helpers::prelude::ArgMatches<'_>,
method: ::rpc_toolkit::command_helpers::prelude::Cow<'_, str>, method: ::rpc_toolkit::command_helpers::prelude::Cow<'_, str>,
@@ -816,6 +877,9 @@ fn cli_handler(
let return_ty = if true { let return_ty = if true {
::rpc_toolkit::command_helpers::prelude::PhantomData ::rpc_toolkit::command_helpers::prelude::PhantomData
} else { } else {
let ctx_new = unreachable!();
::rpc_toolkit::command_helpers::prelude::match_types(&ctx, &ctx_new);
let ctx = ctx_new;
::rpc_toolkit::command_helpers::prelude::make_phantom(#invocation) ::rpc_toolkit::command_helpers::prelude::make_phantom(#invocation)
}; };
@@ -830,13 +894,25 @@ fn cli_handler(
} = opt.exec_ctx } = opt.exec_ctx
{ {
let fn_path = cli; let fn_path = cli;
let cli_param = params.iter().filter_map(|param| match param {
ParamType::Arg(arg) => {
let name = arg.name.clone().unwrap();
let field_name = Ident::new(&format!("arg_{}", name), name.span());
Some(quote! { params.#field_name.clone() })
}
ParamType::Context(_) => Some(quote! { Into::into(ctx) }),
ParamType::ParentData(_) => Some(quote! { parent_data }),
ParamType::Request => None,
ParamType::Response => None,
ParamType::None => unreachable!(),
});
let invocation = if is_async { let invocation = if is_async {
quote! { quote! {
rt_ref.block_on(#fn_path(#(#param),*))? rt_ref.block_on(#fn_path(#(#cli_param),*))?
} }
} else { } else {
quote! { quote! {
#fn_path(#(#param),*)? #fn_path(#(#cli_param),*)?
} }
}; };
let display_res = if let Some(display_fn) = &opt.display { let display_res = if let Some(display_fn) = &opt.display {
@@ -857,7 +933,8 @@ fn cli_handler(
}; };
quote! { quote! {
pub fn cli_handler#generics( pub fn cli_handler#generics(
ctx: #ctx_ty, ctx: GenericContext,
parent_data: #parent_data_ty,
mut rt: Option<::rpc_toolkit::command_helpers::prelude::Runtime>, mut rt: Option<::rpc_toolkit::command_helpers::prelude::Runtime>,
matches: &::rpc_toolkit::command_helpers::prelude::ArgMatches<'_>, matches: &::rpc_toolkit::command_helpers::prelude::ArgMatches<'_>,
_method: ::rpc_toolkit::command_helpers::prelude::Cow<'_, str>, _method: ::rpc_toolkit::command_helpers::prelude::Cow<'_, str>,
@@ -898,7 +975,8 @@ fn cli_handler(
}; };
quote! { quote! {
pub fn cli_handler#generics( pub fn cli_handler#generics(
ctx: #ctx_ty, ctx: GenericContext,
parent_data: #parent_data_ty,
mut rt: Option<::rpc_toolkit::command_helpers::prelude::Runtime>, mut rt: Option<::rpc_toolkit::command_helpers::prelude::Runtime>,
matches: &::rpc_toolkit::command_helpers::prelude::ArgMatches<'_>, matches: &::rpc_toolkit::command_helpers::prelude::ArgMatches<'_>,
_method: ::rpc_toolkit::command_helpers::prelude::Cow<'_, str>, _method: ::rpc_toolkit::command_helpers::prelude::Cow<'_, str>,
@@ -921,11 +999,11 @@ fn cli_handler(
let cmd_preprocess = if common.is_async { let cmd_preprocess = if common.is_async {
quote! { quote! {
#create_rt #create_rt
let ctx = rt_ref.block_on(#fn_path(#(#param),*))?; let parent_data = rt_ref.block_on(#fn_path(#(#param),*))?;
} }
} else { } else {
quote! { quote! {
let ctx = #fn_path(#(#param),*)?; let parent_data = #fn_path(#(#param),*)?;
} }
}; };
let subcmd_impl = subcommands.iter().map(|subcommand| { let subcmd_impl = subcommands.iter().map(|subcommand| {
@@ -939,11 +1017,13 @@ fn cli_handler(
}; };
cli_handler.arguments = match cli_handler.arguments { cli_handler.arguments = match cli_handler.arguments {
PathArguments::None => PathArguments::AngleBracketed( PathArguments::None => PathArguments::AngleBracketed(
syn::parse2(quote! { ::<Params#param_ty_generics> }).unwrap(), syn::parse2(quote! { ::<Params#param_ty_generics, GenericContext> })
.unwrap(),
), ),
PathArguments::AngleBracketed(mut a) => { PathArguments::AngleBracketed(mut a) => {
a.args a.args
.push(syn::parse2(quote! { Params#param_ty_generics }).unwrap()); .push(syn::parse2(quote! { Params#param_ty_generics }).unwrap());
a.args.push(syn::parse2(quote! { GenericContext }).unwrap());
PathArguments::AngleBracketed(a) PathArguments::AngleBracketed(a)
} }
_ => unreachable!(), _ => unreachable!(),
@@ -955,26 +1035,34 @@ fn cli_handler(
} else { } else {
method + "." + #subcommand::NAME method + "." + #subcommand::NAME
}; };
#subcommand::#cli_handler(ctx, rt, sub_m, method, params) #subcommand::#cli_handler(ctx, parent_data, rt, sub_m, method, params)
}, },
} }
}); });
let self_impl = match (self_impl, &common.exec_ctx) { let self_impl = match (self_impl, &common.exec_ctx) {
(Some(self_impl), ExecutionContext::CliOnly(_)) => { (Some(self_impl), ExecutionContext::CliOnly(_))
let self_impl_fn = &self_impl.path; | (Some(self_impl), ExecutionContext::Local(_))
| (Some(self_impl), ExecutionContext::CustomCli { .. }) => {
let (self_impl_fn, is_async) =
if let ExecutionContext::CustomCli { cli, is_async, .. } = &common.exec_ctx
{
(cli, *is_async)
} else {
(&self_impl.path, self_impl.is_async)
};
let create_rt = if common.is_async { let create_rt = if common.is_async {
None None
} else { } else {
Some(create_rt) Some(create_rt)
}; };
let self_impl = if self_impl.is_async { let self_impl = if is_async {
quote_spanned! { self_impl_fn.span() => quote_spanned! { self_impl_fn.span() =>
#create_rt #create_rt
rt_ref.block_on(#self_impl_fn(ctx))? rt_ref.block_on(#self_impl_fn(Into::into(ctx), parent_data))?
} }
} else { } else {
quote_spanned! { self_impl_fn.span() => quote_spanned! { self_impl_fn.span() =>
#self_impl_fn(ctx)? #self_impl_fn(Into::into(ctx), parent_data)?
} }
}; };
quote! { quote! {
@@ -985,11 +1073,11 @@ fn cli_handler(
let self_impl_fn = &self_impl.path; let self_impl_fn = &self_impl.path;
let self_impl = if self_impl.is_async { let self_impl = if self_impl.is_async {
quote! { quote! {
rt_ref.block_on(#self_impl_fn(ctx)) rt_ref.block_on(#self_impl_fn(Into::into(ctx), parent_data))
} }
} else { } else {
quote! { quote! {
#self_impl_fn(ctx) #self_impl_fn(Into::into(ctx), parent_data)
} }
}; };
let create_rt = if common.is_async { let create_rt = if common.is_async {
@@ -1016,7 +1104,7 @@ fn cli_handler(
} }
} }
} }
_ => quote! { (None, _) | (Some(_), ExecutionContext::RpcOnly(_)) => quote! {
Err(::rpc_toolkit::command_helpers::prelude::RpcError { Err(::rpc_toolkit::command_helpers::prelude::RpcError {
data: Some(method.into()), data: Some(method.into()),
..::rpc_toolkit::command_helpers::prelude::yajrc::METHOD_NOT_FOUND_ERROR ..::rpc_toolkit::command_helpers::prelude::yajrc::METHOD_NOT_FOUND_ERROR
@@ -1025,7 +1113,8 @@ fn cli_handler(
}; };
quote! { quote! {
pub fn cli_handler#generics( pub fn cli_handler#generics(
ctx: #ctx_ty, ctx: GenericContext,
parent_data: #parent_data_ty,
mut rt: Option<::rpc_toolkit::command_helpers::prelude::Runtime>, mut rt: Option<::rpc_toolkit::command_helpers::prelude::Runtime>,
matches: &::rpc_toolkit::command_helpers::prelude::ArgMatches<'_>, matches: &::rpc_toolkit::command_helpers::prelude::ArgMatches<'_>,
method: ::rpc_toolkit::command_helpers::prelude::Cow<'_, str>, method: ::rpc_toolkit::command_helpers::prelude::Cow<'_, str>,
@@ -1071,6 +1160,17 @@ pub fn build(args: AttributeArgs, mut item: ItemFn) -> TokenStream {
.map(|a| a.span()) .map(|a| a.span())
.unwrap_or_else(Span::call_site), .unwrap_or_else(Span::call_site),
); );
let ctx_ty = params
.iter()
.find_map(|a| {
if let ParamType::Context(ty) = a {
Some(ty.clone())
} else {
None
}
})
.unwrap_or(macro_try!(syn::parse2(quote! { () })));
let ctx_trait = ctx_trait(ctx_ty, &opt);
let metadata = metadata(&mut opt); let metadata = metadata(&mut opt);
let build_app = build_app(command_name_str.clone(), &mut opt, &mut params); let build_app = build_app(command_name_str.clone(), &mut opt, &mut params);
let rpc_handler = rpc_handler(fn_name, fn_generics, &opt, &params); let rpc_handler = rpc_handler(fn_name, fn_generics, &opt, &params);
@@ -1084,6 +1184,8 @@ pub fn build(args: AttributeArgs, mut item: ItemFn) -> TokenStream {
pub const NAME: &'static str = #command_name_str; pub const NAME: &'static str = #command_name_str;
pub const ASYNC: bool = #is_async; pub const ASYNC: bool = #is_async;
#ctx_trait
#metadata #metadata
#build_app #build_app

View File

@@ -96,6 +96,7 @@ pub enum ParamType {
None, None,
Arg(ArgOptions), Arg(ArgOptions),
Context(Type), Context(Type),
ParentData(Type),
Request, Request,
Response, Response,
} }

View File

@@ -774,6 +774,11 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result<Vec<ParamType>> {
attr.span(), attr.span(),
"`arg` and `context` are mutually exclusive", "`arg` and `context` are mutually exclusive",
)); ));
} else if matches!(ty, ParamType::ParentData(_)) {
return Err(Error::new(
attr.span(),
"`arg` and `parent_data` are mutually exclusive",
));
} else if matches!(ty, ParamType::Request) { } else if matches!(ty, ParamType::Request) {
return Err(Error::new( return Err(Error::new(
attr.span(), attr.span(),
@@ -799,6 +804,11 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result<Vec<ParamType>> {
attr.span(), attr.span(),
"`arg` and `context` are mutually exclusive", "`arg` and `context` are mutually exclusive",
)); ));
} else if matches!(ty, ParamType::ParentData(_)) {
return Err(Error::new(
attr.span(),
"`context` and `parent_data` are mutually exclusive",
));
} else if matches!(ty, ParamType::Request) { } else if matches!(ty, ParamType::Request) {
return Err(Error::new( return Err(Error::new(
attr.span(), attr.span(),
@@ -810,6 +820,36 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result<Vec<ParamType>> {
"`context` and `response` are mutually exclusive", "`context` and `response` are mutually exclusive",
)); ));
} }
} else if param.attrs[i].path.is_ident("parent_data") {
let attr = param.attrs.remove(i);
if matches!(ty, ParamType::None) {
ty = ParamType::ParentData(*param.ty.clone());
} else if matches!(ty, ParamType::ParentData(_)) {
return Err(Error::new(
attr.span(),
"`parent_data` attribute may only be specified once",
));
} else if matches!(ty, ParamType::Arg(_)) {
return Err(Error::new(
attr.span(),
"`arg` and `parent_data` are mutually exclusive",
));
} else if matches!(ty, ParamType::Context(_)) {
return Err(Error::new(
attr.span(),
"`context` and `parent_data` are mutually exclusive",
));
} else if matches!(ty, ParamType::Request) {
return Err(Error::new(
attr.span(),
"`parent_data` and `request` are mutually exclusive",
));
} else if matches!(ty, ParamType::Response) {
return Err(Error::new(
attr.span(),
"`parent_data` and `response` are mutually exclusive",
));
}
} else if param.attrs[i].path.is_ident("request") { } else if param.attrs[i].path.is_ident("request") {
let attr = param.attrs.remove(i); let attr = param.attrs.remove(i);
if matches!(ty, ParamType::None) { if matches!(ty, ParamType::None) {
@@ -829,6 +869,11 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result<Vec<ParamType>> {
attr.span(), attr.span(),
"`context` and `request` are mutually exclusive", "`context` and `request` are mutually exclusive",
)); ));
} else if matches!(ty, ParamType::ParentData(_)) {
return Err(Error::new(
attr.span(),
"`parent_data` and `request` are mutually exclusive",
));
} else if matches!(ty, ParamType::Response) { } else if matches!(ty, ParamType::Response) {
return Err(Error::new( return Err(Error::new(
attr.span(), attr.span(),
@@ -854,6 +899,11 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result<Vec<ParamType>> {
attr.span(), attr.span(),
"`context` and `response` are mutually exclusive", "`context` and `response` are mutually exclusive",
)); ));
} else if matches!(ty, ParamType::Context(_)) {
return Err(Error::new(
attr.span(),
"`parent_data` and `response` are mutually exclusive",
));
} else if matches!(ty, ParamType::Request) { } else if matches!(ty, ParamType::Request) {
return Err(Error::new( return Err(Error::new(
attr.span(), attr.span(),
@@ -867,7 +917,7 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result<Vec<ParamType>> {
if matches!(ty, ParamType::None) { if matches!(ty, ParamType::None) {
return Err(Error::new( return Err(Error::new(
param.span(), param.span(),
"must specify either `arg` or `context` attributes", "must specify either `arg`, `context`, `parent_data`, `request`, or `response` attributes",
)); ));
} }
params.push(ty) params.push(ty)

View File

@@ -6,16 +6,24 @@ use super::*;
pub fn build(args: RpcServerArgs) -> TokenStream { pub fn build(args: RpcServerArgs) -> TokenStream {
let mut command = args.command; let mut command = args.command;
let arguments = std::mem::replace( let mut arguments = std::mem::replace(
&mut command.segments.last_mut().unwrap().arguments, &mut command.segments.last_mut().unwrap().arguments,
PathArguments::None, PathArguments::None,
); );
let command_module = command.clone(); let command_module = command.clone();
if let PathArguments::AngleBracketed(a) = &mut arguments {
a.args.push(syn::parse2(quote! { _ }).unwrap());
}
command.segments.push(PathSegment { command.segments.push(PathSegment {
ident: Ident::new("rpc_handler", command.span()), ident: Ident::new("rpc_handler", command.span()),
arguments, arguments,
}); });
let ctx = args.ctx; let ctx = args.ctx;
let parent_data = if let Some(data) = args.parent_data {
quote! { #data }
} else {
quote! { () }
};
let status_fn = args.status_fn.unwrap_or_else(|| { let status_fn = args.status_fn.unwrap_or_else(|| {
syn::parse2(quote! { |_| ::rpc_toolkit::hyper::StatusCode::OK }).unwrap() syn::parse2(quote! { |_| ::rpc_toolkit::hyper::StatusCode::OK }).unwrap()
}); });
@@ -47,6 +55,7 @@ pub fn build(args: RpcServerArgs) -> TokenStream {
let res = quote! { let res = quote! {
{ {
let ctx = #ctx; let ctx = #ctx;
let parent_data = #parent_data;
let status_fn = #status_fn; let status_fn = #status_fn;
let builder = ::rpc_toolkit::rpc_server_helpers::make_builder(&ctx); let builder = ::rpc_toolkit::rpc_server_helpers::make_builder(&ctx);
#( #(
@@ -54,12 +63,14 @@ pub fn build(args: RpcServerArgs) -> TokenStream {
)* )*
let make_svc = ::rpc_toolkit::hyper::service::make_service_fn(move |_| { let make_svc = ::rpc_toolkit::hyper::service::make_service_fn(move |_| {
let ctx = ctx.clone(); let ctx = ctx.clone();
let parent_data = parent_data.clone();
#( #(
let #middleware_name_clone3 = #middleware_name_clone2.clone(); let #middleware_name_clone3 = #middleware_name_clone2.clone();
)* )*
async move { async move {
Ok::<_, ::rpc_toolkit::hyper::Error>(::rpc_toolkit::hyper::service::service_fn(move |mut req| { Ok::<_, ::rpc_toolkit::hyper::Error>(::rpc_toolkit::hyper::service::service_fn(move |mut req| {
let ctx = ctx.clone(); let ctx = ctx.clone();
let parent_data = parent_data.clone();
let metadata = #command_module::Metadata::default(); let metadata = #command_module::Metadata::default();
#( #(
let #middleware_name_clone5 = #middleware_name_clone4.clone(); let #middleware_name_clone5 = #middleware_name_clone4.clone();
@@ -83,7 +94,7 @@ pub fn build(args: RpcServerArgs) -> TokenStream {
}; };
)* )*
let mut rpc_res = match ::rpc_toolkit::serde_json::from_value(::rpc_toolkit::serde_json::Value::Object(rpc_req.params)) { let mut rpc_res = match ::rpc_toolkit::serde_json::from_value(::rpc_toolkit::serde_json::Value::Object(rpc_req.params)) {
Ok(params) => #command(ctx, &req_parts, &mut res_parts, ::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method), params).await, Ok(params) => #command(ctx, parent_data, &req_parts, &mut res_parts, ::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method), params).await,
Err(e) => Err(e.into()) Err(e) => Err(e.into())
}; };
#( #(

View File

@@ -3,6 +3,7 @@ use syn::*;
pub struct RpcServerArgs { pub struct RpcServerArgs {
command: Path, command: Path,
ctx: Expr, ctx: Expr,
parent_data: Option<Expr>,
status_fn: Option<Expr>, status_fn: Option<Expr>,
middleware: punctuated::Punctuated<Expr, token::Comma>, middleware: punctuated::Punctuated<Expr, token::Comma>,
} }

View File

@@ -9,6 +9,7 @@ impl Parse for RpcServerArgs {
braced!(args in input); braced!(args in input);
let mut command = None; let mut command = None;
let mut ctx = None; let mut ctx = None;
let mut parent_data = None;
let mut status_fn = None; let mut status_fn = None;
let mut middleware = Punctuated::new(); let mut middleware = Punctuated::new();
while !args.is_empty() { while !args.is_empty() {
@@ -21,6 +22,9 @@ impl Parse for RpcServerArgs {
"context" => { "context" => {
ctx = Some(args.parse()?); ctx = Some(args.parse()?);
} }
"parent_data" => {
parent_data = Some(args.parse()?);
}
"status" => { "status" => {
status_fn = Some(args.parse()?); status_fn = Some(args.parse()?);
} }
@@ -38,6 +42,7 @@ impl Parse for RpcServerArgs {
Ok(RpcServerArgs { Ok(RpcServerArgs {
command: command.expect("`command` is required"), command: command.expect("`command` is required"),
ctx: ctx.expect("`context` is required"), ctx: ctx.expect("`context` is required"),
parent_data,
status_fn, status_fn,
middleware, middleware,
}) })

View File

@@ -13,6 +13,7 @@ pub fn build(args: RunCliArgs) -> TokenStream {
let command = command_handler.clone(); let command = command_handler.clone();
if let PathArguments::AngleBracketed(a) = &mut arguments { if let PathArguments::AngleBracketed(a) = &mut arguments {
a.args.push(syn::parse2(quote! { () }).unwrap()); a.args.push(syn::parse2(quote! { () }).unwrap());
a.args.push(syn::parse2(quote! { _ }).unwrap());
} }
command_handler.segments.push(PathSegment { command_handler.segments.push(PathSegment {
ident: Ident::new("cli_handler", command.span()), ident: Ident::new("cli_handler", command.span()),
@@ -42,6 +43,11 @@ pub fn build(args: RunCliArgs) -> TokenStream {
} else { } else {
quote! { &rpc_toolkit_matches } quote! { &rpc_toolkit_matches }
}; };
let parent_data = if let Some(data) = args.parent_data {
quote! { #data }
} else {
quote! { () }
};
let exit_fn = args.exit_fn.unwrap_or_else(|| { let exit_fn = args.exit_fn.unwrap_or_else(|| {
syn::parse2(quote! { |err: ::rpc_toolkit::yajrc::RpcError| { syn::parse2(quote! { |err: ::rpc_toolkit::yajrc::RpcError| {
eprintln!("{}", err.message); eprintln!("{}", err.message);
@@ -56,8 +62,10 @@ pub fn build(args: RunCliArgs) -> TokenStream {
{ {
let rpc_toolkit_matches = #app.get_matches(); let rpc_toolkit_matches = #app.get_matches();
let rpc_toolkit_ctx = #make_ctx; let rpc_toolkit_ctx = #make_ctx;
let rpc_toolkit_parent_data = #parent_data;
if let Err(err) = #command_handler( if let Err(err) = #command_handler(
rpc_toolkit_ctx, rpc_toolkit_ctx,
rpc_toolkit_parent_data,
None, None,
&rpc_toolkit_matches, &rpc_toolkit_matches,
"".into(), "".into(),

View File

@@ -14,6 +14,7 @@ pub struct RunCliArgs {
command: Path, command: Path,
mut_app: Option<MutApp>, mut_app: Option<MutApp>,
make_ctx: Option<MakeCtx>, make_ctx: Option<MakeCtx>,
parent_data: Option<Expr>,
exit_fn: Option<Expr>, exit_fn: Option<Expr>,
} }

View File

@@ -45,6 +45,14 @@ impl Parse for RunCliArgs {
if !input.is_empty() { if !input.is_empty() {
let _: token::Comma = input.parse()?; let _: token::Comma = input.parse()?;
} }
if !input.is_empty() {
let _: token::Comma = input.parse()?;
}
let parent_data = if !input.is_empty() {
Some(input.parse()?)
} else {
None
};
let exit_fn = if !input.is_empty() { let exit_fn = if !input.is_empty() {
Some(input.parse()?) Some(input.parse()?)
} else { } else {
@@ -54,6 +62,7 @@ impl Parse for RunCliArgs {
command, command,
mut_app, mut_app,
make_ctx, make_ctx,
parent_data,
exit_fn, exit_fn,
}) })
} }

View File

@@ -33,3 +33,9 @@ pub trait Context {
} }
impl Context for () {} impl Context for () {}
impl<'a, T: Context + 'a> From<T> for Box<dyn Context + 'a> {
fn from(ctx: T) -> Self {
Box::new(ctx)
}
}

View File

@@ -8,8 +8,7 @@ use rpc_toolkit::clap::Arg;
use rpc_toolkit::hyper::http::Error as HttpError; use rpc_toolkit::hyper::http::Error as HttpError;
use rpc_toolkit::hyper::{Body, Response}; use rpc_toolkit::hyper::{Body, Response};
use rpc_toolkit::rpc_server_helpers::{ use rpc_toolkit::rpc_server_helpers::{
constrain_middleware, DynMiddleware, DynMiddlewareStage2, DynMiddlewareStage3, DynMiddlewareStage2, DynMiddlewareStage3, DynMiddlewareStage4,
DynMiddlewareStage4,
}; };
use rpc_toolkit::serde::{Deserialize, Serialize}; use rpc_toolkit::serde::{Deserialize, Serialize};
use rpc_toolkit::url::Host; use rpc_toolkit::url::Host;
@@ -17,34 +16,29 @@ use rpc_toolkit::yajrc::RpcError;
use rpc_toolkit::{command, rpc_server, run_cli, Context, Metadata}; use rpc_toolkit::{command, rpc_server, run_cli, Context, Metadata};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct AppState<T, U> { pub struct AppState(Arc<ConfigSeed>);
seed: T, impl From<AppState> for () {
data: U, fn from(_: AppState) -> Self {
} ()
impl<T, U> AppState<T, U> {
pub fn map<F: FnOnce(U) -> V, V>(self, f: F) -> AppState<T, V> {
AppState {
seed: self.seed,
data: f(self.data),
}
} }
} }
#[derive(Debug)]
pub struct ConfigSeed { pub struct ConfigSeed {
host: Host, host: Host,
port: u16, port: u16,
} }
impl<T> Context for AppState<Arc<ConfigSeed>, T> { impl Context for AppState {
fn host(&self) -> Host<&str> { fn host(&self) -> Host<&str> {
match &self.seed.host { match &self.0.host {
Host::Domain(s) => Host::Domain(s.as_str()), Host::Domain(s) => Host::Domain(s.as_str()),
Host::Ipv4(i) => Host::Ipv4(*i), Host::Ipv4(i) => Host::Ipv4(*i),
Host::Ipv6(i) => Host::Ipv6(*i), Host::Ipv6(i) => Host::Ipv6(*i),
} }
} }
fn port(&self) -> u16 { fn port(&self) -> u16 {
self.seed.port self.0.port
} }
} }
@@ -53,41 +47,43 @@ impl<T> Context for AppState<Arc<ConfigSeed>, T> {
subcommands("dothething2::<U, E>", self(dothething_impl(async))) subcommands("dothething2::<U, E>", self(dothething_impl(async)))
)] )]
async fn dothething< async fn dothething<
U: Serialize + for<'a> Deserialize<'a> + FromStr<Err = E> + Clone, U: Serialize + for<'a> Deserialize<'a> + FromStr<Err = E> + Clone + 'static,
E: Display, E: Display,
>( >(
#[context] ctx: AppState<Arc<ConfigSeed>, ()>, #[context] _ctx: AppState,
#[arg(short = "a")] arg1: Option<String>, #[arg(short = "a")] arg1: Option<String>,
#[arg(short = "b")] val: String, #[arg(short = "b")] val: String,
#[arg(short = "c", help = "I am the flag `c`!")] arg3: bool, #[arg(short = "c", help = "I am the flag `c`!")] arg3: bool,
#[arg(stdin)] structured: U, #[arg(stdin)] structured: U,
) -> Result<AppState<Arc<ConfigSeed>, (Option<String>, String, bool, U)>, RpcError> { ) -> Result<(Option<String>, String, bool, U), RpcError> {
Ok(ctx.map(|_| (arg1, val, arg3, structured))) Ok((arg1, val, arg3, structured))
} }
async fn dothething_impl<U: Serialize>( async fn dothething_impl<U: Serialize>(
ctx: AppState<Arc<ConfigSeed>, (Option<String>, String, bool, U)>, ctx: AppState,
parent_data: (Option<String>, String, bool, U),
) -> Result<String, RpcError> { ) -> Result<String, RpcError> {
Ok(format!( Ok(format!(
"{:?}, {}, {}, {}", "{:?}, {:?}, {}, {}, {}",
ctx.data.0, ctx,
ctx.data.1, parent_data.0,
ctx.data.2, parent_data.1,
serde_json::to_string_pretty(&ctx.data.3)? parent_data.2,
serde_json::to_string_pretty(&parent_data.3)?
)) ))
} }
#[command(about = "Does the thing")] #[command(about = "Does the thing")]
fn dothething2<U: Serialize + for<'a> Deserialize<'a> + FromStr<Err = E>, E: Display>( fn dothething2<U: Serialize + for<'a> Deserialize<'a> + FromStr<Err = E>, E: Display>(
#[context] ctx: AppState<Arc<ConfigSeed>, (Option<String>, String, bool, U)>, #[parent_data] parent_data: (Option<String>, String, bool, U),
#[arg(stdin)] structured2: U, #[arg(stdin)] structured2: U,
) -> Result<String, RpcError> { ) -> Result<String, RpcError> {
Ok(format!( Ok(format!(
"{:?}, {}, {}, {}, {}", "{:?}, {}, {}, {}, {}",
ctx.data.0, parent_data.0,
ctx.data.1, parent_data.1,
ctx.data.2, parent_data.2,
serde_json::to_string_pretty(&ctx.data.3)?, serde_json::to_string_pretty(&parent_data.3)?,
serde_json::to_string_pretty(&structured2)?, serde_json::to_string_pretty(&structured2)?,
)) ))
} }
@@ -134,7 +130,7 @@ async fn test_rpc() {
}); });
let server = rpc_server!({ let server = rpc_server!({
command: dothething::<String, _>, command: dothething::<String, _>,
context: AppState { seed, data: () }, context: AppState(seed),
middleware: [ middleware: [
cors, cors,
], ],
@@ -192,14 +188,8 @@ fn cli_test() {
host: Host::parse("localhost").unwrap(), host: Host::parse("localhost").unwrap(),
port: 8000, port: 8000,
}); });
dothething::cli_handler::<String, _, _>( dothething::cli_handler::<String, _, _, _>(AppState(seed), (), None, &matches, "".into(), ())
AppState { seed, data: () }, .unwrap();
None,
&matches,
"".into(),
(),
)
.unwrap();
} }
#[test] #[test]
@@ -210,14 +200,11 @@ fn cli_example() {
app => app app => app
.arg(Arg::with_name("host").long("host").short("h").takes_value(true)) .arg(Arg::with_name("host").long("host").short("h").takes_value(true))
.arg(Arg::with_name("port").long("port").short("p").takes_value(true)), .arg(Arg::with_name("port").long("port").short("p").takes_value(true)),
matches => AppState { seed: Arc::new(ConfigSeed { matches => AppState(Arc::new(ConfigSeed {
host: Host::parse(matches.value_of("host").unwrap_or("localhost")).unwrap(), host: Host::parse(matches.value_of("host").unwrap_or("localhost")).unwrap(),
port: matches.value_of("port").unwrap_or("8000").parse().unwrap(), port: matches.value_of("port").unwrap_or("8000").parse().unwrap(),
}), data: () } }))
) )
} }
fn type_check() { ////////////////////////////////////////////////
let middleware: DynMiddleware<dothething::Metadata> = todo!();
constrain_middleware(&middleware);
}