From b449c4265a32ea4af88843fce4665e81a2c70d4f Mon Sep 17 00:00:00 2001 From: Aiden McClelland Date: Wed, 16 Jun 2021 15:29:06 -0600 Subject: [PATCH] middleware --- Cargo.lock | 108 ++++++++++++++---- .../src/rpc_server/build.rs | 75 +++++++++--- .../src/rpc_server/mod.rs | 1 + .../src/rpc_server/parse.rs | 45 ++++++-- rpc-toolkit/Cargo.toml | 15 +-- rpc-toolkit/src/rpc_server_helpers.rs | 43 ++++++- rpc-toolkit/tests/test.rs | 40 ++++++- 7 files changed, 266 insertions(+), 61 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bf1275f..9c9173d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -142,42 +142,97 @@ dependencies = [ ] [[package]] -name = "futures-channel" -version = "0.3.13" +name = "futures" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c2dd2df839b57db9ab69c2c9d8f3e8c81984781937fe2807dc6dcf3b2ad2939" +checksum = "0e7e43a803dae2fa37c1f6a8fe121e1f7bf9548b4dfc0522a42f34145dadfc27" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e682a68b29a882df0545c143dc3646daefe80ba479bcdede94d5a703de2871e2" dependencies = [ "futures-core", + "futures-sink", ] [[package]] name = "futures-core" -version = "0.3.13" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15496a72fabf0e62bdc3df11a59a3787429221dd0710ba8ef163d6f7a9112c94" +checksum = "0402f765d8a89a26043b889b26ce3c4679d268fa6bb22cd7c6aad98340e179d1" [[package]] -name = "futures-sink" -version = "0.3.13" +name = "futures-executor" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85754d98985841b7d4f5e8e6fbfa4a4ac847916893ec511a2917ccd8525b8bb3" - -[[package]] -name = "futures-task" -version = "0.3.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa189ef211c15ee602667a6fcfe1c1fd9e07d42250d2156382820fba33c9df80" - -[[package]] -name = "futures-util" -version = "0.3.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1812c7ab8aedf8d6f2701a43e1243acdbcc2b36ab26e2ad421eb99ac963d96d1" +checksum = "badaa6a909fac9e7236d0620a2f57f7664640c56575b71a7552fbd68deafab79" dependencies = [ "futures-core", "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acc499defb3b348f8d8f3f66415835a9131856ff7714bf10dadfc4ec4bdb29a1" + +[[package]] +name = "futures-macro" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4c40298486cdf52cc00cd6d6987892ba502c7656a16a4192a9992b1ccedd121" +dependencies = [ + "autocfg", + "proc-macro-hack", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a57bead0ceff0d6dde8f465ecd96c9338121bb7717d3e7b108059531870c4282" + +[[package]] +name = "futures-task" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a16bef9fc1a4dddb5bee51c989e3fbba26569cbb0e31f5b303c184e3dd33dae" + +[[package]] +name = "futures-util" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "feb5c238d27e2bf94ffdfd27b2c29e3df4a68c4193bb6427384259e2bf191967" +dependencies = [ + "autocfg", + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "proc-macro-hack", + "proc-macro-nested", + "slab", ] [[package]] @@ -574,6 +629,18 @@ version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857" +[[package]] +name = "proc-macro-hack" +version = "0.5.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" + +[[package]] +name = "proc-macro-nested" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc881b2c22681370c6a780e47af9840ef841837bc98118431d4e1868bd0c1086" + [[package]] name = "proc-macro2" version = "1.0.26" @@ -689,6 +756,7 @@ name = "rpc-toolkit" version = "0.1.0" dependencies = [ "clap", + "futures", "hyper", "lazy_static", "reqwest", diff --git a/rpc-toolkit-macro-internals/src/rpc_server/build.rs b/rpc-toolkit-macro-internals/src/rpc_server/build.rs index d96f034..c81c629 100644 --- a/rpc-toolkit-macro-internals/src/rpc_server/build.rs +++ b/rpc-toolkit-macro-internals/src/rpc_server/build.rs @@ -1,4 +1,4 @@ -use proc_macro2::TokenStream; +use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::spanned::Spanned; @@ -18,7 +18,20 @@ pub fn build(args: RpcServerArgs) -> TokenStream { let status_fn = args.status_fn.unwrap_or_else(|| { syn::parse2(quote! { |_| ::rpc_toolkit::hyper::StatusCode::OK }).unwrap() }); - quote! { + let middleware_name_pre = (0..) + .map(|i| Ident::new(&format!("middleware_pre_{}", i), Span::call_site())) + .take(args.middleware.len()); + let middleware_name_pre2 = middleware_name_pre.clone(); + let middleware_name = (0..) + .map(|i| Ident::new(&format!("middleware_{}", i), Span::call_site())) + .take(args.middleware.len()); + let middleware_name_inv = middleware_name + .clone() + .collect::>() + .into_iter() + .rev(); + let middleware = args.middleware.iter(); + let res = quote! { { let ctx = #ctx; let status_fn = #status_fn; @@ -29,28 +42,52 @@ pub fn build(args: RpcServerArgs) -> TokenStream { Ok::<_, ::rpc_toolkit::hyper::Error>(::rpc_toolkit::hyper::service::service_fn(move |mut req| { let ctx = ctx.clone(); async move { + #( + let #middleware_name_pre = match ::rpc_toolkit::rpc_server_helpers::constrain_middleware(#middleware)(&mut req).await? { + Ok(a) => a, + Err(res) => return Ok(res), + }; + )* let rpc_req = ::rpc_toolkit::rpc_server_helpers::make_request(&mut req).await; - ::rpc_toolkit::rpc_server_helpers::to_response( - &req, - match rpc_req { - Ok(rpc_req) => Ok(( - rpc_req.id, - #command( - ctx, - ::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method), - rpc_req.params, - ) - .await, - )), - Err(e) => Err(e), - }, - status_fn, - ) + match rpc_req { + Ok(mut rpc_req) => { + #( + let #middleware_name = match #middleware_name_pre2(&mut rpc_req).await? { + Ok(a) => a, + Err(res) => return Ok(res), + }; + )* + let mut rpc_res = ::rpc_toolkit::rpc_server_helpers::to_response( + &req, + Ok(( + rpc_req.id, + #command( + ctx, + ::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method), + rpc_req.params, + ) + .await, + )), + status_fn, + )?; + #( + #middleware_name_inv(&mut rpc_res).await?; + )* + Ok::<_, ::rpc_toolkit::hyper::http::Error>(rpc_res) + } + Err(e) => ::rpc_toolkit::rpc_server_helpers::to_response( + &req, + Err(e), + status_fn, + ), + } } })) } }); builder.serve(make_svc) } - } + }; + // panic!("{}", res); + res } diff --git a/rpc-toolkit-macro-internals/src/rpc_server/mod.rs b/rpc-toolkit-macro-internals/src/rpc_server/mod.rs index debe917..4e8f703 100644 --- a/rpc-toolkit-macro-internals/src/rpc_server/mod.rs +++ b/rpc-toolkit-macro-internals/src/rpc_server/mod.rs @@ -4,6 +4,7 @@ pub struct RpcServerArgs { command: Path, ctx: Expr, status_fn: Option, + middleware: punctuated::Punctuated, } pub mod build; diff --git a/rpc-toolkit-macro-internals/src/rpc_server/parse.rs b/rpc-toolkit-macro-internals/src/rpc_server/parse.rs index b959298..3d03afd 100644 --- a/rpc-toolkit-macro-internals/src/rpc_server/parse.rs +++ b/rpc-toolkit-macro-internals/src/rpc_server/parse.rs @@ -1,24 +1,45 @@ use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; use super::*; impl Parse for RpcServerArgs { fn parse(input: ParseStream) -> Result { - let command = input.parse()?; - let _: token::Comma = input.parse()?; - let ctx = input.parse()?; - if !input.is_empty() { - let _: token::Comma = input.parse()?; + let args; + braced!(args in input); + let mut command = None; + let mut ctx = None; + let mut status_fn = None; + let mut middleware = Punctuated::new(); + while !args.is_empty() { + let arg_name: syn::Ident = args.parse()?; + let _: token::Colon = args.parse()?; + match arg_name.to_string().as_str() { + "command" => { + command = Some(args.parse()?); + } + "context" => { + ctx = Some(args.parse()?); + } + "status" => { + status_fn = Some(args.parse()?); + } + "middleware" => { + let middlewares; + bracketed!(middlewares in args); + middleware = middlewares.parse_terminated(Expr::parse)?; + } + _ => return Err(Error::new(arg_name.span(), "unknown argument")), + } + if !args.is_empty() { + let _: token::Comma = args.parse()?; + } } - let status_fn = if !input.is_empty() { - Some(input.parse()?) - } else { - None - }; Ok(RpcServerArgs { - command, - ctx, + command: command.expect("`command` is required"), + ctx: ctx.expect("`context` is required"), status_fn, + middleware, }) } } diff --git a/rpc-toolkit/Cargo.toml b/rpc-toolkit/Cargo.toml index a10a543..9ac5c66 100644 --- a/rpc-toolkit/Cargo.toml +++ b/rpc-toolkit/Cargo.toml @@ -11,14 +11,15 @@ default = ["cbor"] [dependencies] clap = "2.33.3" -hyper = { version = "0.14.5", features = ["server", "http1", "http2", "tcp", "stream", "client"] } +futures = "0.3.15" +hyper = { version="0.14.5", features=["server", "http1", "http2", "tcp", "stream", "client"] } lazy_static = "1.4.0" -reqwest = { version = "0.11.2" } -rpc-toolkit-macro = { path = "../rpc-toolkit-macro" } -serde = { version = "1.0.125", features = ["derive"] } -serde_cbor = { version = "0.11.1", optional = true } +reqwest = { version="0.11.2" } +rpc-toolkit-macro = { path="../rpc-toolkit-macro" } +serde = { version="1.0.125", features=["derive"] } +serde_cbor = { version="0.11.1", optional=true } serde_json = "1.0.64" thiserror = "1.0.24" -tokio = { version = "1.4.0", features = ["full"] } +tokio = { version="1.4.0", features=["full"] } url = "2.2.1" -yajrc = { version = "*", path = "../../yajrc" } +yajrc = { version="*", path="../../yajrc" } diff --git a/rpc-toolkit/src/rpc_server_helpers.rs b/rpc-toolkit/src/rpc_server_helpers.rs index a9ef7cf..4808b52 100644 --- a/rpc-toolkit/src/rpc_server_helpers.rs +++ b/rpc-toolkit/src/rpc_server_helpers.rs @@ -1,4 +1,8 @@ +use std::future::Future; + +use futures::future::BoxFuture; use hyper::body::Buf; +use hyper::http::Error as HttpError; use hyper::server::conn::AddrIncoming; use hyper::server::{Builder, Server}; use hyper::{Body, Request, Response, StatusCode}; @@ -58,7 +62,7 @@ pub fn to_response StatusCode>( req: &Request, res: Result<(Option, Result), RpcError>, status_code_fn: F, -) -> Result, hyper::http::Error> { +) -> Result, HttpError> { let rpc_res: RpcResponse = match res { Ok((id, result)) => RpcResponse { id, result }, Err(e) => e.into(), @@ -94,3 +98,40 @@ pub fn to_response StatusCode>( }); res.body(Body::from(body)) } + +pub type DynMiddleware<'a, 'b, 'c, Params> = Box< + dyn FnOnce( + &'a mut Request, + ) -> BoxFuture< + 'a, + Result, Response>, HttpError>, + > + Send + + Sync, +>; +pub type DynMiddlewareStage2<'a, 'b, Params> = Box< + dyn FnOnce( + &'a mut RpcRequest>, + ) + -> BoxFuture<'a, Result, Response>, HttpError>> + + Send + + Sync, +>; +pub type DynMiddlewareStage3<'a> = + Box) -> BoxFuture<'a, Result<(), HttpError>> + Send + Sync>; + +pub fn constrain_middleware< + 'a, + 'b, + 'c, + Params: for<'de> Deserialize<'de> + 'static, + ReqFn: Fn(&'a mut Request) -> ReqFut, + ReqFut: Future>, HttpError>> + 'a, + RpcReqFn: FnOnce(&'b mut RpcRequest>) -> RpcReqFut, + RpcReqFut: Future>, HttpError>> + 'b, + ResFn: FnOnce(&'c mut Response) -> ResFut, + ResFut: Future> + 'c, +>( + f: ReqFn, +) -> ReqFn { + f +} diff --git a/rpc-toolkit/tests/test.rs b/rpc-toolkit/tests/test.rs index 24faaec..ec50245 100644 --- a/rpc-toolkit/tests/test.rs +++ b/rpc-toolkit/tests/test.rs @@ -2,7 +2,12 @@ use std::fmt::Display; use std::str::FromStr; use std::sync::Arc; +use futures::FutureExt; +use hyper::Request; use rpc_toolkit::clap::Arg; +use rpc_toolkit::hyper::http::Error as HttpError; +use rpc_toolkit::hyper::{Body, Response}; +use rpc_toolkit::rpc_server_helpers::{DynMiddlewareStage2, DynMiddlewareStage3}; use rpc_toolkit::serde::{Deserialize, Serialize}; use rpc_toolkit::url::Host; use rpc_toolkit::yajrc::RpcError; @@ -84,15 +89,46 @@ fn dothething2 Deserialize<'a> + FromStr, E: Dis )) } +async fn cors<'a, 'b, Params: for<'de> Deserialize<'de> + 'static>( + req: &mut Request, +) -> Result, Response>, HttpError> { + if req.method() == hyper::Method::OPTIONS { + Ok(Err(Response::builder() + .header("Access-Control-Allow-Origin", "*") + .body(Body::empty())?)) + } else { + Ok(Ok(Box::new(|_req| { + async move { + let res: DynMiddlewareStage3 = Box::new(|res| { + async move { + res.headers_mut() + .insert("Access-Control-Allow-Origin", "*".parse()?); + Ok::<_, HttpError>(()) + } + .boxed() + }); + Ok::<_, HttpError>(Ok(res)) + } + .boxed() + }))) + } +} + #[tokio::test] -async fn test() { +async fn test_rpc() { use tokio::io::AsyncWriteExt; let seed = Arc::new(ConfigSeed { host: Host::parse("localhost").unwrap(), port: 8000, }); - let server = rpc_server!(dothething::, AppState { seed, data: () }); + let server = rpc_server!({ + command: dothething::, + context: AppState { seed, data: () }, + middleware: [ + cors, + ], + }); let handle = tokio::spawn(server); let mut cmd = tokio::process::Command::new("cargo") .arg("test")