diff --git a/rpc-toolkit-macro-internals/src/command/build.rs b/rpc-toolkit-macro-internals/src/command/build.rs index 297d0d9..6c6778b 100644 --- a/rpc-toolkit-macro-internals/src/command/build.rs +++ b/rpc-toolkit-macro-internals/src/command/build.rs @@ -10,6 +10,172 @@ use syn::token::{Comma, Where}; use super::parse::*; use super::*; +fn metadata(full_options: &Options) -> TokenStream { + let options = match full_options { + Options::Leaf(a) => a, + Options::Parent(ParentOptions { common, .. }) => common, + }; + let fallthrough = |ty: &str| { + let getter_name = Ident::new(&format!("get_{}", ty), Span::call_site()); + match &*full_options { + Options::Parent(ParentOptions { subcommands, .. }) => { + let subcmd_handler = subcommands.iter().map(|subcmd| { + let mut subcmd = subcmd.clone(); + subcmd.segments.last_mut().unwrap().arguments = PathArguments::None; + quote_spanned!{ subcmd.span() => + [#subcmd::NAME, rest] => if let Some(val) = #subcmd::Metadata.#getter_name(rest, key) { + return Some(val); + }, + } + }); + quote! { + if !command.is_empty() { + match command.splitn(2, ".").chain(std::iter::repeat("")).take(2).collect::>().as_slice() { + #( + #subcmd_handler + )* + _ => () + } + } + } + } + _ => quote! {}, + } + }; + fn impl_getter>( + ty: &str, + metadata: I, + fallthrough: TokenStream, + ) -> TokenStream { + let getter_name = Ident::new(&format!("get_{}", ty), Span::call_site()); + let ty: Type = syn::parse_str(ty).unwrap(); + quote! { + fn #getter_name(&self, command: &str, key: &str) -> Option<#ty> { + #fallthrough + match key { + #(#metadata)* + _ => None, + } + } + } + } + let bool_metadata = options + .metadata + .iter() + .filter(|(_, lit)| matches!(lit, Lit::Bool(_))) + .map(|(name, value)| { + let name = LitStr::new(&name.to_string(), name.span()); + quote! { + #name => Some(#value), + } + }); + let number_metadata = |ty: &str| { + let ty: Type = syn::parse_str(ty).unwrap(); + options + .metadata + .iter() + .filter(|(_, lit)| matches!(lit, Lit::Int(_) | Lit::Float(_) | Lit::Byte(_))) + .map(move |(name, value)| { + let name = LitStr::new(&name.to_string(), name.span()); + quote! { + #name => Some(#value as #ty), + } + }) + }; + let char_metadata = options + .metadata + .iter() + .filter(|(_, lit)| matches!(lit, Lit::Char(_))) + .map(|(name, value)| { + let name = LitStr::new(&name.to_string(), name.span()); + quote! { + #name => Some(#value), + } + }); + let str_metadata = options + .metadata + .iter() + .filter(|(_, lit)| matches!(lit, Lit::Str(_))) + .map(|(name, value)| { + let name = LitStr::new(&name.to_string(), name.span()); + quote! { + #name => Some(#value), + } + }); + let bstr_metadata = options + .metadata + .iter() + .filter(|(_, lit)| matches!(lit, Lit::ByteStr(_))) + .map(|(name, value)| { + let name = LitStr::new(&name.to_string(), name.span()); + quote! { + #name => Some(#value), + } + }); + + let bool_getter = impl_getter("bool", bool_metadata, fallthrough("bool")); + let u8_getter = impl_getter("u8", number_metadata("u8"), fallthrough("u8")); + let u16_getter = impl_getter("u16", number_metadata("u16"), fallthrough("u16")); + let u32_getter = impl_getter("u32", number_metadata("u32"), fallthrough("u32")); + let u64_getter = impl_getter("u64", number_metadata("u64"), fallthrough("u64")); + let usize_getter = impl_getter("usize", number_metadata("usize"), fallthrough("usize")); + let i8_getter = impl_getter("i8", number_metadata("i8"), fallthrough("i8")); + let i16_getter = impl_getter("i16", number_metadata("i16"), fallthrough("i16")); + let i32_getter = impl_getter("i32", number_metadata("i32"), fallthrough("i32")); + let i64_getter = impl_getter("i64", number_metadata("i64"), fallthrough("i64")); + let isize_getter = impl_getter("isize", number_metadata("isize"), fallthrough("isize")); + let f32_getter = impl_getter("f32", number_metadata("f32"), fallthrough("f32")); + let f64_getter = impl_getter("f64", number_metadata("f64"), fallthrough("f64")); + let char_getter = impl_getter("char", char_metadata, fallthrough("char")); + let str_fallthrough = fallthrough("str"); + let str_getter = quote! { + fn get_str(&self, command: &str, key: &str) -> Option<&'static str> { + #str_fallthrough + match key { + #(#str_metadata)* + _ => None, + } + } + }; + let bstr_fallthrough = fallthrough("bstr"); + let bstr_getter = quote! { + fn get_bstr(&self, command: &str, key: &str) -> Option<&'static [u8]> { + #bstr_fallthrough + match key { + #(#bstr_metadata)* + _ => None, + } + } + }; + + let res = quote! { + #[derive(Clone, Copy, Default)] + pub struct Metadata; + + #[allow(overflowing_literals)] + impl ::rpc_toolkit::Metadata for Metadata { + #bool_getter + #u8_getter + #u16_getter + #u32_getter + #u64_getter + #usize_getter + #i8_getter + #i16_getter + #i32_getter + #i64_getter + #isize_getter + #f32_getter + #f64_getter + #char_getter + #str_getter + #bstr_getter + } + }; + // panic!("{}", res); + res +} + fn build_app(name: LitStr, opt: &mut Options, params: &mut [ParamType]) -> TokenStream { let about = opt.common().about.clone().into_iter(); let (subcommand, subcommand_required) = if let Options::Parent(opt) = opt { @@ -835,6 +1001,7 @@ pub fn build(args: AttributeArgs, mut item: ItemFn) -> TokenStream { .map(|a| a.span()) .unwrap_or_else(Span::call_site), ); + let metadata = metadata(&mut opt); let build_app = build_app(command_name_str.clone(), &mut opt, &mut params); let rpc_handler = rpc_handler(fn_name, fn_generics, &opt, ¶ms); let cli_handler = cli_handler(fn_name, fn_generics, &mut opt, ¶ms); @@ -847,6 +1014,8 @@ pub fn build(args: AttributeArgs, mut item: ItemFn) -> TokenStream { pub const NAME: &'static str = #command_name_str; pub const ASYNC: bool = #is_async; + #metadata + #build_app #rpc_handler diff --git a/rpc-toolkit-macro-internals/src/command/mod.rs b/rpc-toolkit-macro-internals/src/command/mod.rs index 35b5a04..ba8d4f2 100644 --- a/rpc-toolkit-macro-internals/src/command/mod.rs +++ b/rpc-toolkit-macro-internals/src/command/mod.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use syn::*; pub mod build; @@ -24,6 +26,7 @@ pub struct LeafOptions { rename: Option, exec_ctx: ExecutionContext, display: Option, + metadata: HashMap, } pub struct SelfImplInfo { diff --git a/rpc-toolkit-macro-internals/src/command/parse.rs b/rpc-toolkit-macro-internals/src/command/parse.rs index ffa76b1..422df68 100644 --- a/rpc-toolkit-macro-internals/src/command/parse.rs +++ b/rpc-toolkit-macro-internals/src/command/parse.rs @@ -305,6 +305,47 @@ pub fn parse_command_attr(args: AttributeArgs) -> Result { NestedMeta::Meta(Meta::Path(p)) if p.is_ident("rename") => { return Err(Error::new(p.span(), "`rename` must be assigned to")); } + NestedMeta::Meta(Meta::List(list)) if list.path.is_ident("metadata") => { + for meta in list.nested { + match meta { + NestedMeta::Meta(Meta::NameValue(metadata_pair)) => { + let ident = metadata_pair.path.get_ident().ok_or(Error::new( + metadata_pair.path.span(), + "must be an identifier", + ))?; + if opt + .common() + .metadata + .insert(ident.clone(), metadata_pair.lit) + .is_some() + { + return Err(Error::new( + ident.span(), + format!("duplicate metadata `{}`", ident), + )); + } + } + a => { + return Err(Error::new( + a.span(), + "`metadata` takes a list of identifiers to be assigned to", + )) + } + } + } + } + NestedMeta::Meta(Meta::Path(p)) if p.is_ident("metadata") => { + return Err(Error::new( + p.span(), + "`metadata` takes a list of identifiers to be assigned to", + )); + } + NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("metadata") => { + return Err(Error::new( + nv.path.span(), + "`metadata` cannot be assigned to", + )); + } _ => { return Err(Error::new(arg.span(), "unknown argument")); } diff --git a/rpc-toolkit-macro-internals/src/rpc_server/build.rs b/rpc-toolkit-macro-internals/src/rpc_server/build.rs index c81c629..48c901c 100644 --- a/rpc-toolkit-macro-internals/src/rpc_server/build.rs +++ b/rpc-toolkit-macro-internals/src/rpc_server/build.rs @@ -10,6 +10,7 @@ pub fn build(args: RpcServerArgs) -> TokenStream { &mut command.segments.last_mut().unwrap().arguments, PathArguments::None, ); + let command_module = command.clone(); command.segments.push(PathSegment { ident: Ident::new("rpc_handler", command.span()), arguments, @@ -22,14 +23,18 @@ pub fn build(args: RpcServerArgs) -> TokenStream { .map(|i| Ident::new(&format!("middleware_pre_{}", i), Span::call_site())) .take(args.middleware.len()); let middleware_name_pre2 = middleware_name_pre.clone(); - let middleware_name = (0..) - .map(|i| Ident::new(&format!("middleware_{}", i), Span::call_site())) + let middleware_name_post = (0..) + .map(|i| Ident::new(&format!("middleware_post_{}", i), Span::call_site())) .take(args.middleware.len()); - let middleware_name_inv = middleware_name + let middleware_name_post_inv = middleware_name_post .clone() .collect::>() .into_iter() .rev(); + let middleware_name = (0..) + .map(|i| Ident::new(&format!("middleware_{}", i), Span::call_site())) + .take(args.middleware.len()); + let middleware_name2 = middleware_name.clone(); let middleware = args.middleware.iter(); let res = quote! { { @@ -41,9 +46,10 @@ pub fn build(args: RpcServerArgs) -> TokenStream { async move { Ok::<_, ::rpc_toolkit::hyper::Error>(::rpc_toolkit::hyper::service::service_fn(move |mut req| { let ctx = ctx.clone(); + let metadata = #command_module::Metadata::default(); async move { #( - let #middleware_name_pre = match ::rpc_toolkit::rpc_server_helpers::constrain_middleware(#middleware)(&mut req).await? { + let #middleware_name_pre = match ::rpc_toolkit::rpc_server_helpers::constrain_middleware(#middleware)(&mut req, metadata).await? { Ok(a) => a, Err(res) => return Ok(res), }; @@ -52,28 +58,35 @@ pub fn build(args: RpcServerArgs) -> TokenStream { match rpc_req { Ok(mut rpc_req) => { #( - let #middleware_name = match #middleware_name_pre2(&mut rpc_req).await? { + let #middleware_name_post = match #middleware_name_pre2(&mut rpc_req).await? { Ok(a) => a, Err(res) => return Ok(res), }; )* - let mut rpc_res = ::rpc_toolkit::rpc_server_helpers::to_response( + let mut rpc_res = #command( + ctx, + ::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method), + rpc_req.params, + ) + .await; + #( + let #middleware_name = match #middleware_name_post_inv(&mut rpc_res).await? { + Ok(a) => a, + Err(res) => return Ok(res), + }; + )* + let mut res = ::rpc_toolkit::rpc_server_helpers::to_response( &req, Ok(( rpc_req.id, - #command( - ctx, - ::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method), - rpc_req.params, - ) - .await, + rpc_res, )), status_fn, )?; #( - #middleware_name_inv(&mut rpc_res).await?; + #middleware_name2(&mut res).await?; )* - Ok::<_, ::rpc_toolkit::hyper::http::Error>(rpc_res) + Ok::<_, ::rpc_toolkit::hyper::http::Error>(res) } Err(e) => ::rpc_toolkit::rpc_server_helpers::to_response( &req, diff --git a/rpc-toolkit/src/lib.rs b/rpc-toolkit/src/lib.rs index 962a993..894cfad 100644 --- a/rpc-toolkit/src/lib.rs +++ b/rpc-toolkit/src/lib.rs @@ -42,7 +42,9 @@ pub use rpc_toolkit_macro::run_cli; pub use {clap, hyper, reqwest, serde, serde_json, tokio, url, yajrc}; pub use crate::context::Context; +pub use crate::metadata::Metadata; pub mod command_helpers; mod context; +mod metadata; pub mod rpc_server_helpers; diff --git a/rpc-toolkit/src/metadata.rs b/rpc-toolkit/src/metadata.rs new file mode 100644 index 0000000..c300f76 --- /dev/null +++ b/rpc-toolkit/src/metadata.rs @@ -0,0 +1,68 @@ +macro_rules! getter_for { + ($($name:ident => $t:ty,)*) => { + $( + #[allow(unused_variables)] + fn $name(&self, command: &str, key: &str) -> Option<$t> { + None + } + )* + }; +} + +pub trait Metadata: Copy + Default + Send + Sync + 'static { + fn get(&self, command: &str, key: &str) -> Option { + Ty::from_metadata(self, command, key) + } + getter_for!( + get_bool => bool, + get_u8 => u8, + get_u16 => u16, + get_u32 => u32, + get_u64 => u64, + get_usize => usize, + get_i8 => i8, + get_i16 => i16, + get_i32 => i32, + get_i64 => i64, + get_isize => isize, + get_f32 => f32, + get_f64 => f64, + get_char => char, + get_str => &'static str, + get_bstr => &'static [u8], + ); +} + +macro_rules! impl_primitive_for { + ($($name:ident => $t:ty,)*) => { + $( + impl Primitive for $t { + fn from_metadata(m: &M, command: &str, key: &str) -> Option { + m.$name(command, key) + } + } + )* + }; +} + +pub trait Primitive: Copy { + fn from_metadata(m: &M, command: &str, key: &str) -> Option; +} +impl_primitive_for!( + get_bool => bool, + get_u8 => u8, + get_u16 => u16, + get_u32 => u32, + get_u64 => u64, + get_usize => usize, + get_i8 => i8, + get_i16 => i16, + get_i32 => i32, + get_i64 => i64, + get_isize => isize, + get_f32 => f32, + get_f64 => f64, + get_char => char, + get_str => &'static str, + get_bstr => &'static [u8], +); diff --git a/rpc-toolkit/src/rpc_server_helpers.rs b/rpc-toolkit/src/rpc_server_helpers.rs index ed4d626..174b435 100644 --- a/rpc-toolkit/src/rpc_server_helpers.rs +++ b/rpc-toolkit/src/rpc_server_helpers.rs @@ -1,6 +1,7 @@ use std::future::Future; use futures::future::BoxFuture; +use futures::FutureExt; use hyper::body::Buf; use hyper::http::Error as HttpError; use hyper::server::conn::AddrIncoming; @@ -12,7 +13,7 @@ use serde_json::Value; use url::Host; use yajrc::{AnyRpcMethod, GenericRpcMethod, Id, RpcError, RpcRequest, RpcResponse}; -use crate::Context; +use crate::{Context, Metadata}; lazy_static! { #[cfg(feature = "cbor")] @@ -100,15 +101,20 @@ pub fn to_response StatusCode>( } // &mut Request -> Result -> Future -> Future>, Response>, HttpError>>>, Response>, HttpError> -pub type DynMiddleware = Box< +pub type DynMiddleware = Box< dyn for<'a> FnOnce( &'a mut Request, + Metadata, ) -> BoxFuture< 'a, Result, Response>, HttpError>, > + Send + Sync, >; +pub fn noop Deserialize<'de> + 'static, M: Metadata>() -> DynMiddleware +{ + Box::new(|_, _| async { Ok(Ok(noop2())) }.boxed()) +} pub type DynMiddlewareStage2 = Box< dyn for<'a> FnOnce( &'a mut RpcRequest>, @@ -118,23 +124,45 @@ pub type DynMiddlewareStage2 = Box< > + Send + Sync, >; +pub fn noop2 Deserialize<'de> + 'static>() -> DynMiddlewareStage2 { + Box::new(|_| async { Ok(Ok(noop3())) }.boxed()) +} pub type DynMiddlewareStage3 = Box< + dyn for<'a> FnOnce( + &'a mut Result, + ) -> BoxFuture< + 'a, + Result>, HttpError>, + > + Send + + Sync, +>; +pub fn noop3() -> DynMiddlewareStage3 { + Box::new(|_| async { Ok(Ok(noop4())) }.boxed()) +} +pub type DynMiddlewareStage4 = Box< dyn for<'a> FnOnce(&'a mut Response) -> BoxFuture<'a, Result<(), HttpError>> + Send + Sync, >; +pub fn noop4() -> DynMiddlewareStage4 { + Box::new(|_| async { Ok(()) }.boxed()) +} pub fn constrain_middleware< 'a, 'b, 'c, + 'd, Params: for<'de> Deserialize<'de> + 'static, - ReqFn: Fn(&'a mut Request) -> ReqFut, + M: Metadata, + ReqFn: Fn(&'a mut Request, M) -> ReqFut, ReqFut: Future>, HttpError>> + 'a, RpcReqFn: FnOnce(&'b mut RpcRequest>) -> RpcReqFut, - RpcReqFut: Future>, HttpError>> + 'b, - ResFn: FnOnce(&'c mut Response) -> ResFut, - ResFut: Future> + 'c, + RpcReqFut: Future>, HttpError>> + 'b, + RpcResFn: FnOnce(&'c mut Result) -> RpcResFut, + RpcResFut: Future>, HttpError>> + 'c, + ResFn: FnOnce(&'d mut Response) -> ResFut, + ResFut: Future> + 'd, >( f: ReqFn, ) -> ReqFn { diff --git a/rpc-toolkit/tests/test.rs b/rpc-toolkit/tests/test.rs index c260621..65018ec 100644 --- a/rpc-toolkit/tests/test.rs +++ b/rpc-toolkit/tests/test.rs @@ -7,11 +7,13 @@ use hyper::Request; use rpc_toolkit::clap::Arg; use rpc_toolkit::hyper::http::Error as HttpError; use rpc_toolkit::hyper::{Body, Response}; -use rpc_toolkit::rpc_server_helpers::{DynMiddlewareStage2, DynMiddlewareStage3}; +use rpc_toolkit::rpc_server_helpers::{ + DynMiddlewareStage2, DynMiddlewareStage3, DynMiddlewareStage4, +}; use rpc_toolkit::serde::{Deserialize, Serialize}; use rpc_toolkit::url::Host; use rpc_toolkit::yajrc::RpcError; -use rpc_toolkit::{command, rpc_server, run_cli, Context}; +use rpc_toolkit::{command, rpc_server, run_cli, Context, Metadata}; #[derive(Debug, Clone)] pub struct AppState { @@ -89,21 +91,28 @@ fn dothething2 Deserialize<'a> + FromStr, E: Dis )) } -async fn cors Deserialize<'de> + 'static>( +async fn cors Deserialize<'de> + 'static, M: Metadata + 'static>( req: &mut Request, + _: M, ) -> Result, Response>, HttpError> { if req.method() == hyper::Method::OPTIONS { Ok(Err(Response::builder() .header("Access-Control-Allow-Origin", "*") .body(Body::empty())?)) } else { - Ok(Ok(Box::new(|_req| { + Ok(Ok(Box::new(|_| { async move { - let res: DynMiddlewareStage3 = Box::new(|res| { + let res: DynMiddlewareStage3 = Box::new(|_| { async move { - res.headers_mut() - .insert("Access-Control-Allow-Origin", "*".parse()?); - Ok::<_, HttpError>(()) + let res: DynMiddlewareStage4 = Box::new(|res| { + async move { + res.headers_mut() + .insert("Access-Control-Allow-Origin", "*".parse()?); + Ok::<_, HttpError>(()) + } + .boxed() + }); + Ok::<_, HttpError>(Ok(res)) } .boxed() });