middleware

This commit is contained in:
Aiden McClelland
2021-06-16 15:29:06 -06:00
parent fb2629d099
commit b449c4265a
7 changed files with 266 additions and 61 deletions

108
Cargo.lock generated
View File

@@ -142,42 +142,97 @@ dependencies = [
] ]
[[package]] [[package]]
name = "futures-channel" name = "futures"
version = "0.3.13" version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c2dd2df839b57db9ab69c2c9d8f3e8c81984781937fe2807dc6dcf3b2ad2939" checksum = "0e7e43a803dae2fa37c1f6a8fe121e1f7bf9548b4dfc0522a42f34145dadfc27"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e682a68b29a882df0545c143dc3646daefe80ba479bcdede94d5a703de2871e2"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-sink",
] ]
[[package]] [[package]]
name = "futures-core" name = "futures-core"
version = "0.3.13" version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15496a72fabf0e62bdc3df11a59a3787429221dd0710ba8ef163d6f7a9112c94" checksum = "0402f765d8a89a26043b889b26ce3c4679d268fa6bb22cd7c6aad98340e179d1"
[[package]] [[package]]
name = "futures-sink" name = "futures-executor"
version = "0.3.13" version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85754d98985841b7d4f5e8e6fbfa4a4ac847916893ec511a2917ccd8525b8bb3" checksum = "badaa6a909fac9e7236d0620a2f57f7664640c56575b71a7552fbd68deafab79"
[[package]]
name = "futures-task"
version = "0.3.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa189ef211c15ee602667a6fcfe1c1fd9e07d42250d2156382820fba33c9df80"
[[package]]
name = "futures-util"
version = "0.3.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1812c7ab8aedf8d6f2701a43e1243acdbcc2b36ab26e2ad421eb99ac963d96d1"
dependencies = [ dependencies = [
"futures-core", "futures-core",
"futures-task", "futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acc499defb3b348f8d8f3f66415835a9131856ff7714bf10dadfc4ec4bdb29a1"
[[package]]
name = "futures-macro"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4c40298486cdf52cc00cd6d6987892ba502c7656a16a4192a9992b1ccedd121"
dependencies = [
"autocfg",
"proc-macro-hack",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a57bead0ceff0d6dde8f465ecd96c9338121bb7717d3e7b108059531870c4282"
[[package]]
name = "futures-task"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a16bef9fc1a4dddb5bee51c989e3fbba26569cbb0e31f5b303c184e3dd33dae"
[[package]]
name = "futures-util"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "feb5c238d27e2bf94ffdfd27b2c29e3df4a68c4193bb6427384259e2bf191967"
dependencies = [
"autocfg",
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite", "pin-project-lite",
"pin-utils", "pin-utils",
"proc-macro-hack",
"proc-macro-nested",
"slab",
] ]
[[package]] [[package]]
@@ -574,6 +629,18 @@ version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857" checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857"
[[package]]
name = "proc-macro-hack"
version = "0.5.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5"
[[package]]
name = "proc-macro-nested"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc881b2c22681370c6a780e47af9840ef841837bc98118431d4e1868bd0c1086"
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.26" version = "1.0.26"
@@ -689,6 +756,7 @@ name = "rpc-toolkit"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"clap", "clap",
"futures",
"hyper", "hyper",
"lazy_static", "lazy_static",
"reqwest", "reqwest",

View File

@@ -1,4 +1,4 @@
use proc_macro2::TokenStream; use proc_macro2::{Span, TokenStream};
use quote::quote; use quote::quote;
use syn::spanned::Spanned; use syn::spanned::Spanned;
@@ -18,7 +18,20 @@ pub fn build(args: RpcServerArgs) -> TokenStream {
let status_fn = args.status_fn.unwrap_or_else(|| { let status_fn = args.status_fn.unwrap_or_else(|| {
syn::parse2(quote! { |_| ::rpc_toolkit::hyper::StatusCode::OK }).unwrap() syn::parse2(quote! { |_| ::rpc_toolkit::hyper::StatusCode::OK }).unwrap()
}); });
quote! { let middleware_name_pre = (0..)
.map(|i| Ident::new(&format!("middleware_pre_{}", i), Span::call_site()))
.take(args.middleware.len());
let middleware_name_pre2 = middleware_name_pre.clone();
let middleware_name = (0..)
.map(|i| Ident::new(&format!("middleware_{}", i), Span::call_site()))
.take(args.middleware.len());
let middleware_name_inv = middleware_name
.clone()
.collect::<Vec<_>>()
.into_iter()
.rev();
let middleware = args.middleware.iter();
let res = quote! {
{ {
let ctx = #ctx; let ctx = #ctx;
let status_fn = #status_fn; let status_fn = #status_fn;
@@ -29,28 +42,52 @@ pub fn build(args: RpcServerArgs) -> TokenStream {
Ok::<_, ::rpc_toolkit::hyper::Error>(::rpc_toolkit::hyper::service::service_fn(move |mut req| { Ok::<_, ::rpc_toolkit::hyper::Error>(::rpc_toolkit::hyper::service::service_fn(move |mut req| {
let ctx = ctx.clone(); let ctx = ctx.clone();
async move { async move {
#(
let #middleware_name_pre = match ::rpc_toolkit::rpc_server_helpers::constrain_middleware(#middleware)(&mut req).await? {
Ok(a) => a,
Err(res) => return Ok(res),
};
)*
let rpc_req = ::rpc_toolkit::rpc_server_helpers::make_request(&mut req).await; let rpc_req = ::rpc_toolkit::rpc_server_helpers::make_request(&mut req).await;
::rpc_toolkit::rpc_server_helpers::to_response( match rpc_req {
&req, Ok(mut rpc_req) => {
match rpc_req { #(
Ok(rpc_req) => Ok(( let #middleware_name = match #middleware_name_pre2(&mut rpc_req).await? {
rpc_req.id, Ok(a) => a,
#command( Err(res) => return Ok(res),
ctx, };
::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method), )*
rpc_req.params, let mut rpc_res = ::rpc_toolkit::rpc_server_helpers::to_response(
) &req,
.await, Ok((
)), rpc_req.id,
Err(e) => Err(e), #command(
}, ctx,
status_fn, ::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method),
) rpc_req.params,
)
.await,
)),
status_fn,
)?;
#(
#middleware_name_inv(&mut rpc_res).await?;
)*
Ok::<_, ::rpc_toolkit::hyper::http::Error>(rpc_res)
}
Err(e) => ::rpc_toolkit::rpc_server_helpers::to_response(
&req,
Err(e),
status_fn,
),
}
} }
})) }))
} }
}); });
builder.serve(make_svc) builder.serve(make_svc)
} }
} };
// panic!("{}", res);
res
} }

View File

@@ -4,6 +4,7 @@ pub struct RpcServerArgs {
command: Path, command: Path,
ctx: Expr, ctx: Expr,
status_fn: Option<Expr>, status_fn: Option<Expr>,
middleware: punctuated::Punctuated<Expr, token::Comma>,
} }
pub mod build; pub mod build;

View File

@@ -1,24 +1,45 @@
use syn::parse::{Parse, ParseStream}; use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use super::*; use super::*;
impl Parse for RpcServerArgs { impl Parse for RpcServerArgs {
fn parse(input: ParseStream) -> Result<Self> { fn parse(input: ParseStream) -> Result<Self> {
let command = input.parse()?; let args;
let _: token::Comma = input.parse()?; braced!(args in input);
let ctx = input.parse()?; let mut command = None;
if !input.is_empty() { let mut ctx = None;
let _: token::Comma = input.parse()?; let mut status_fn = None;
let mut middleware = Punctuated::new();
while !args.is_empty() {
let arg_name: syn::Ident = args.parse()?;
let _: token::Colon = args.parse()?;
match arg_name.to_string().as_str() {
"command" => {
command = Some(args.parse()?);
}
"context" => {
ctx = Some(args.parse()?);
}
"status" => {
status_fn = Some(args.parse()?);
}
"middleware" => {
let middlewares;
bracketed!(middlewares in args);
middleware = middlewares.parse_terminated(Expr::parse)?;
}
_ => return Err(Error::new(arg_name.span(), "unknown argument")),
}
if !args.is_empty() {
let _: token::Comma = args.parse()?;
}
} }
let status_fn = if !input.is_empty() {
Some(input.parse()?)
} else {
None
};
Ok(RpcServerArgs { Ok(RpcServerArgs {
command, command: command.expect("`command` is required"),
ctx, ctx: ctx.expect("`context` is required"),
status_fn, status_fn,
middleware,
}) })
} }
} }

View File

@@ -11,14 +11,15 @@ default = ["cbor"]
[dependencies] [dependencies]
clap = "2.33.3" clap = "2.33.3"
hyper = { version = "0.14.5", features = ["server", "http1", "http2", "tcp", "stream", "client"] } futures = "0.3.15"
hyper = { version="0.14.5", features=["server", "http1", "http2", "tcp", "stream", "client"] }
lazy_static = "1.4.0" lazy_static = "1.4.0"
reqwest = { version = "0.11.2" } reqwest = { version="0.11.2" }
rpc-toolkit-macro = { path = "../rpc-toolkit-macro" } rpc-toolkit-macro = { path="../rpc-toolkit-macro" }
serde = { version = "1.0.125", features = ["derive"] } serde = { version="1.0.125", features=["derive"] }
serde_cbor = { version = "0.11.1", optional = true } serde_cbor = { version="0.11.1", optional=true }
serde_json = "1.0.64" serde_json = "1.0.64"
thiserror = "1.0.24" thiserror = "1.0.24"
tokio = { version = "1.4.0", features = ["full"] } tokio = { version="1.4.0", features=["full"] }
url = "2.2.1" url = "2.2.1"
yajrc = { version = "*", path = "../../yajrc" } yajrc = { version="*", path="../../yajrc" }

View File

@@ -1,4 +1,8 @@
use std::future::Future;
use futures::future::BoxFuture;
use hyper::body::Buf; use hyper::body::Buf;
use hyper::http::Error as HttpError;
use hyper::server::conn::AddrIncoming; use hyper::server::conn::AddrIncoming;
use hyper::server::{Builder, Server}; use hyper::server::{Builder, Server};
use hyper::{Body, Request, Response, StatusCode}; use hyper::{Body, Request, Response, StatusCode};
@@ -58,7 +62,7 @@ pub fn to_response<F: Fn(i32) -> StatusCode>(
req: &Request<Body>, req: &Request<Body>,
res: Result<(Option<Id>, Result<Value, RpcError>), RpcError>, res: Result<(Option<Id>, Result<Value, RpcError>), RpcError>,
status_code_fn: F, status_code_fn: F,
) -> Result<Response<Body>, hyper::http::Error> { ) -> Result<Response<Body>, HttpError> {
let rpc_res: RpcResponse = match res { let rpc_res: RpcResponse = match res {
Ok((id, result)) => RpcResponse { id, result }, Ok((id, result)) => RpcResponse { id, result },
Err(e) => e.into(), Err(e) => e.into(),
@@ -94,3 +98,40 @@ pub fn to_response<F: Fn(i32) -> StatusCode>(
}); });
res.body(Body::from(body)) res.body(Body::from(body))
} }
pub type DynMiddleware<'a, 'b, 'c, Params> = Box<
dyn FnOnce(
&'a mut Request<Body>,
) -> BoxFuture<
'a,
Result<Result<DynMiddlewareStage2<'b, 'c, Params>, Response<Body>>, HttpError>,
> + Send
+ Sync,
>;
pub type DynMiddlewareStage2<'a, 'b, Params> = Box<
dyn FnOnce(
&'a mut RpcRequest<GenericRpcMethod<String, Params>>,
)
-> BoxFuture<'a, Result<Result<DynMiddlewareStage3<'b>, Response<Body>>, HttpError>>
+ Send
+ Sync,
>;
pub type DynMiddlewareStage3<'a> =
Box<dyn FnOnce(&'a mut Response<Body>) -> BoxFuture<'a, Result<(), HttpError>> + Send + Sync>;
pub fn constrain_middleware<
'a,
'b,
'c,
Params: for<'de> Deserialize<'de> + 'static,
ReqFn: Fn(&'a mut Request<Body>) -> ReqFut,
ReqFut: Future<Output = Result<Result<RpcReqFn, Response<Body>>, HttpError>> + 'a,
RpcReqFn: FnOnce(&'b mut RpcRequest<GenericRpcMethod<String, Params>>) -> RpcReqFut,
RpcReqFut: Future<Output = Result<Result<ResFn, Response<Body>>, HttpError>> + 'b,
ResFn: FnOnce(&'c mut Response<Body>) -> ResFut,
ResFut: Future<Output = Result<(), HttpError>> + 'c,
>(
f: ReqFn,
) -> ReqFn {
f
}

View File

@@ -2,7 +2,12 @@ use std::fmt::Display;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use futures::FutureExt;
use hyper::Request;
use rpc_toolkit::clap::Arg; use rpc_toolkit::clap::Arg;
use rpc_toolkit::hyper::http::Error as HttpError;
use rpc_toolkit::hyper::{Body, Response};
use rpc_toolkit::rpc_server_helpers::{DynMiddlewareStage2, DynMiddlewareStage3};
use rpc_toolkit::serde::{Deserialize, Serialize}; use rpc_toolkit::serde::{Deserialize, Serialize};
use rpc_toolkit::url::Host; use rpc_toolkit::url::Host;
use rpc_toolkit::yajrc::RpcError; use rpc_toolkit::yajrc::RpcError;
@@ -84,15 +89,46 @@ fn dothething2<U: Serialize + for<'a> Deserialize<'a> + FromStr<Err = E>, E: Dis
)) ))
} }
async fn cors<'a, 'b, Params: for<'de> Deserialize<'de> + 'static>(
req: &mut Request<Body>,
) -> Result<Result<DynMiddlewareStage2<'a, 'b, Params>, Response<Body>>, HttpError> {
if req.method() == hyper::Method::OPTIONS {
Ok(Err(Response::builder()
.header("Access-Control-Allow-Origin", "*")
.body(Body::empty())?))
} else {
Ok(Ok(Box::new(|_req| {
async move {
let res: DynMiddlewareStage3 = Box::new(|res| {
async move {
res.headers_mut()
.insert("Access-Control-Allow-Origin", "*".parse()?);
Ok::<_, HttpError>(())
}
.boxed()
});
Ok::<_, HttpError>(Ok(res))
}
.boxed()
})))
}
}
#[tokio::test] #[tokio::test]
async fn test() { async fn test_rpc() {
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
let seed = Arc::new(ConfigSeed { let seed = Arc::new(ConfigSeed {
host: Host::parse("localhost").unwrap(), host: Host::parse("localhost").unwrap(),
port: 8000, port: 8000,
}); });
let server = rpc_server!(dothething::<String, _>, AppState { seed, data: () }); let server = rpc_server!({
command: dothething::<String, _>,
context: AppState { seed, data: () },
middleware: [
cors,
],
});
let handle = tokio::spawn(server); let handle = tokio::spawn(server);
let mut cmd = tokio::process::Command::new("cargo") let mut cmd = tokio::process::Command::new("cargo")
.arg("test") .arg("test")