access request and response headers from rpc body

This commit is contained in:
Aiden McClelland
2021-07-28 16:44:15 -06:00
parent da0d08fa35
commit 1e9ded9a31
7 changed files with 153 additions and 53 deletions

View File

@@ -444,7 +444,9 @@ fn rpc_handler(
quote! { args.#field_name } quote! { args.#field_name }
} }
ParamType::Context(_) => quote! { ctx }, ParamType::Context(_) => quote! { ctx },
_ => unreachable!(), ParamType::Request => quote! { request },
ParamType::Response => quote! { response },
ParamType::None => unreachable!(),
}); });
match opt { match opt {
Options::Leaf(opt) if matches!(opt.exec_ctx, ExecutionContext::CliOnly(_)) => quote! { Options::Leaf(opt) if matches!(opt.exec_ctx, ExecutionContext::CliOnly(_)) => quote! {
@@ -452,6 +454,8 @@ fn rpc_handler(
pub async fn rpc_handler#fn_generics( pub async fn rpc_handler#fn_generics(
_ctx: #ctx_ty, _ctx: #ctx_ty,
_request: &::rpc_toolkit::command_helpers::prelude::RequestParts,
_response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts,
method: &str, method: &str,
_args: Params#param_ty_generics, _args: Params#param_ty_generics,
) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> { ) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> {
@@ -480,6 +484,8 @@ fn rpc_handler(
pub async fn rpc_handler#fn_generics( pub async fn rpc_handler#fn_generics(
ctx: #ctx_ty, ctx: #ctx_ty,
request: &::rpc_toolkit::command_helpers::prelude::RequestParts,
response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts,
method: &str, method: &str,
args: Params#param_ty_generics, args: Params#param_ty_generics,
) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> { ) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> {
@@ -515,7 +521,7 @@ fn rpc_handler(
), ),
}; };
quote_spanned!{ subcommand.span() => quote_spanned!{ subcommand.span() =>
[#subcommand::NAME, rest] => #subcommand::#rpc_handler(ctx, rest, ::rpc_toolkit::command_helpers::prelude::from_value(args.rest)?).await [#subcommand::NAME, rest] => #subcommand::#rpc_handler(ctx, request, response, rest, ::rpc_toolkit::command_helpers::prelude::from_value(args.rest)?).await
} }
}); });
let subcmd_impl = quote! { let subcmd_impl = quote! {
@@ -550,6 +556,8 @@ fn rpc_handler(
pub async fn rpc_handler#fn_generics( pub async fn rpc_handler#fn_generics(
ctx: #ctx_ty, ctx: #ctx_ty,
request: &::rpc_toolkit::command_helpers::prelude::RequestParts,
response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts,
method: &str, method: &str,
args: Params#param_ty_generics, args: Params#param_ty_generics,
) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> { ) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> {
@@ -569,6 +577,8 @@ fn rpc_handler(
pub async fn rpc_handler#fn_generics( pub async fn rpc_handler#fn_generics(
ctx: #ctx_ty, ctx: #ctx_ty,
request: &::rpc_toolkit::command_helpers::prelude::RequestParts,
response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts,
method: &str, method: &str,
args: Params#param_ty_generics, args: Params#param_ty_generics,
) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> { ) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> {
@@ -618,7 +628,9 @@ fn cli_handler(
quote! { params.#field_name.clone() } quote! { params.#field_name.clone() }
} }
ParamType::Context(_) => quote! { ctx }, ParamType::Context(_) => quote! { ctx },
_ => unreachable!(), ParamType::Request => quote! { request },
ParamType::Response => quote! { response },
ParamType::None => unreachable!(),
}); });
let mut param_generics_filter = GenericFilter::new(fn_generics); let mut param_generics_filter = GenericFilter::new(fn_generics);
for param in params { for param in params {

View File

@@ -90,4 +90,6 @@ pub enum ParamType {
None, None,
Arg(ArgOptions), Arg(ArgOptions),
Context(Type), Context(Type),
Request,
Response,
} }

View File

@@ -666,6 +666,16 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result<Vec<ParamType>> {
attr.span(), attr.span(),
"`arg` and `context` are mutually exclusive", "`arg` and `context` are mutually exclusive",
)); ));
} else if matches!(ty, ParamType::Request) {
return Err(Error::new(
attr.span(),
"`arg` and `request` are mutually exclusive",
));
} else if matches!(ty, ParamType::Response) {
return Err(Error::new(
attr.span(),
"`arg` and `response` are mutually exclusive",
));
} }
} else if param.attrs[i].path.is_ident("context") { } else if param.attrs[i].path.is_ident("context") {
let attr = param.attrs.remove(i); let attr = param.attrs.remove(i);
@@ -681,6 +691,66 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result<Vec<ParamType>> {
attr.span(), attr.span(),
"`arg` and `context` are mutually exclusive", "`arg` and `context` are mutually exclusive",
)); ));
} else if matches!(ty, ParamType::Request) {
return Err(Error::new(
attr.span(),
"`context` and `request` are mutually exclusive",
));
} else if matches!(ty, ParamType::Response) {
return Err(Error::new(
attr.span(),
"`context` and `response` are mutually exclusive",
));
}
} else if param.attrs[i].path.is_ident("request") {
let attr = param.attrs.remove(i);
if matches!(ty, ParamType::None) {
ty = ParamType::Request;
} else if matches!(ty, ParamType::Request) {
return Err(Error::new(
attr.span(),
"`request` attribute may only be specified once",
));
} else if matches!(ty, ParamType::Arg(_)) {
return Err(Error::new(
attr.span(),
"`arg` and `request` are mutually exclusive",
));
} else if matches!(ty, ParamType::Context(_)) {
return Err(Error::new(
attr.span(),
"`context` and `request` are mutually exclusive",
));
} else if matches!(ty, ParamType::Response) {
return Err(Error::new(
attr.span(),
"`request` and `response` are mutually exclusive",
));
}
} else if param.attrs[i].path.is_ident("response") {
let attr = param.attrs.remove(i);
if matches!(ty, ParamType::None) {
ty = ParamType::Response;
} else if matches!(ty, ParamType::Response) {
return Err(Error::new(
attr.span(),
"`response` attribute may only be specified once",
));
} else if matches!(ty, ParamType::Arg(_)) {
return Err(Error::new(
attr.span(),
"`arg` and `response` are mutually exclusive",
));
} else if matches!(ty, ParamType::Context(_)) {
return Err(Error::new(
attr.span(),
"`context` and `response` are mutually exclusive",
));
} else if matches!(ty, ParamType::Request) {
return Err(Error::new(
attr.span(),
"`request` and `response` are mutually exclusive",
));
} }
} else { } else {
i += 1; i += 1;

View File

@@ -54,29 +54,30 @@ pub fn build(args: RpcServerArgs) -> TokenStream {
Err(res) => return Ok(res), Err(res) => return Ok(res),
}; };
)* )*
let rpc_req = ::rpc_toolkit::rpc_server_helpers::make_request(&mut req).await; let (mut req_parts, req_body) = req.into_parts();
let (mut res_parts, _) = ::rpc_toolkit::hyper::Response::new(()).into_parts();
let rpc_req = ::rpc_toolkit::rpc_server_helpers::make_request(&req_parts, req_body).await;
match rpc_req { match rpc_req {
Ok(mut rpc_req) => { Ok(mut rpc_req) => {
#( #(
let #middleware_name_post = match #middleware_name_pre2(&mut rpc_req).await? { let #middleware_name_post = match #middleware_name_pre2(&mut req_parts, &mut rpc_req).await? {
Ok(a) => a, Ok(a) => a,
Err(res) => return Ok(res), Err(res) => return Ok(res),
}; };
)* )*
let mut rpc_res = #command( let mut rpc_res = match ::rpc_toolkit::serde_json::from_value(::rpc_toolkit::serde_json::Value::Object(rpc_req.params)) {
ctx, Ok(params) => #command(ctx, &req_parts, &mut res_parts, ::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method), params).await,
::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method), Err(e) => Err(e.into())
rpc_req.params, };
)
.await;
#( #(
let #middleware_name = match #middleware_name_post_inv(&mut rpc_res).await? { let #middleware_name = match #middleware_name_post_inv(&mut res_parts, &mut rpc_res).await? {
Ok(a) => a, Ok(a) => a,
Err(res) => return Ok(res), Err(res) => return Ok(res),
}; };
)* )*
let mut res = ::rpc_toolkit::rpc_server_helpers::to_response( let mut res = ::rpc_toolkit::rpc_server_helpers::to_response(
&req, &req_parts.headers,
res_parts,
Ok(( Ok((
rpc_req.id, rpc_req.id,
rpc_res, rpc_res,
@@ -89,7 +90,8 @@ pub fn build(args: RpcServerArgs) -> TokenStream {
Ok::<_, ::rpc_toolkit::hyper::http::Error>(res) Ok::<_, ::rpc_toolkit::hyper::http::Error>(res)
} }
Err(e) => ::rpc_toolkit::rpc_server_helpers::to_response( Err(e) => ::rpc_toolkit::rpc_server_helpers::to_response(
&req, &req_parts.headers,
res_parts,
Err(e), Err(e),
status_fn, status_fn,
), ),

View File

@@ -16,6 +16,8 @@ pub mod prelude {
pub use std::marker::PhantomData; pub use std::marker::PhantomData;
pub use clap::{App, AppSettings, Arg, ArgMatches}; pub use clap::{App, AppSettings, Arg, ArgMatches};
pub use hyper::http::request::Parts as RequestParts;
pub use hyper::http::response::Parts as ResponseParts;
pub use serde::{Deserialize, Serialize}; pub use serde::{Deserialize, Serialize};
pub use serde_json::{from_value, to_value, Value}; pub use serde_json::{from_value, to_value, Value};
pub use tokio::runtime::Runtime; pub use tokio::runtime::Runtime;

View File

@@ -3,13 +3,15 @@ use std::future::Future;
use futures::future::BoxFuture; use futures::future::BoxFuture;
use futures::FutureExt; use futures::FutureExt;
use hyper::body::Buf; use hyper::body::Buf;
use hyper::header::HeaderValue;
use hyper::http::request::Parts as RequestParts;
use hyper::http::response::Parts as ResponseParts;
use hyper::http::Error as HttpError; use hyper::http::Error as HttpError;
use hyper::server::conn::AddrIncoming; use hyper::server::conn::AddrIncoming;
use hyper::server::{Builder, Server}; use hyper::server::{Builder, Server};
use hyper::{Body, Request, Response, StatusCode}; use hyper::{Body, HeaderMap, Request, Response, StatusCode};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use serde::Deserialize; use serde_json::{Map, Value};
use serde_json::Value;
use url::Host; use url::Host;
use yajrc::{AnyRpcMethod, GenericRpcMethod, Id, RpcError, RpcRequest, RpcResponse}; use yajrc::{AnyRpcMethod, GenericRpcMethod, Id, RpcError, RpcRequest, RpcResponse};
@@ -33,16 +35,15 @@ pub fn make_builder<Ctx: Context>(ctx: &Ctx) -> Builder<AddrIncoming> {
Server::bind(&addr) Server::bind(&addr)
} }
pub async fn make_request<Params: for<'de> Deserialize<'de> + 'static>( pub async fn make_request(
req: &mut Request<Body>, req_parts: &RequestParts,
) -> Result<RpcRequest<GenericRpcMethod<String, Params>>, RpcError> { req_body: Body,
let body = hyper::body::aggregate(std::mem::replace(req.body_mut(), Body::empty())) ) -> Result<RpcRequest<GenericRpcMethod<String, Map<String, Value>>>, RpcError> {
.await? let body = hyper::body::aggregate(req_body).await?.reader();
.reader(); let rpc_req: RpcRequest<GenericRpcMethod<String, Map<String, Value>>>;
let rpc_req: RpcRequest<GenericRpcMethod<String, Params>>;
#[cfg(feature = "cbor")] #[cfg(feature = "cbor")]
if req if req_parts
.headers() .headers
.get("content-type") .get("content-type")
.and_then(|h| h.to_str().ok()) .and_then(|h| h.to_str().ok())
== Some("application/cbor") == Some("application/cbor")
@@ -60,7 +61,8 @@ pub async fn make_request<Params: for<'de> Deserialize<'de> + 'static>(
} }
pub fn to_response<F: Fn(i32) -> StatusCode>( pub fn to_response<F: Fn(i32) -> StatusCode>(
req: &Request<Body>, req_headers: &HeaderMap<HeaderValue>,
mut res_parts: ResponseParts,
res: Result<(Option<Id>, Result<Value, RpcError>), RpcError>, res: Result<(Option<Id>, Result<Value, RpcError>), RpcError>,
status_code_fn: F, status_code_fn: F,
) -> Result<Response<Body>, HttpError> { ) -> Result<Response<Body>, HttpError> {
@@ -69,10 +71,8 @@ pub fn to_response<F: Fn(i32) -> StatusCode>(
Err(e) => e.into(), Err(e) => e.into(),
}; };
let body; let body;
let mut res = Response::builder();
#[cfg(feature = "cbor")] #[cfg(feature = "cbor")]
if req if req_headers
.headers()
.get("accept") .get("accept")
.and_then(|h| h.to_str().ok()) .and_then(|h| h.to_str().ok())
.iter() .iter()
@@ -81,54 +81,64 @@ pub fn to_response<F: Fn(i32) -> StatusCode>(
.any(|s| s == "application/cbor") .any(|s| s == "application/cbor")
// prefer cbor if accepted // prefer cbor if accepted
{ {
res = res.header("content-type", "application/cbor"); res_parts
.headers
.insert("content-type", HeaderValue::from_static("application/cbor"));
body = serde_cbor::to_vec(&rpc_res).unwrap_or_else(|_| CBOR_INTERNAL_ERROR.clone()); body = serde_cbor::to_vec(&rpc_res).unwrap_or_else(|_| CBOR_INTERNAL_ERROR.clone());
} else { } else {
res = res.header("content-type", "application/json"); res_parts
.headers
.insert("content-type", HeaderValue::from_static("application/json"));
body = serde_json::to_vec(&rpc_res).unwrap_or_else(|_| JSON_INTERNAL_ERROR.clone()); body = serde_json::to_vec(&rpc_res).unwrap_or_else(|_| JSON_INTERNAL_ERROR.clone());
} }
#[cfg(not(feature = "cbor"))] #[cfg(not(feature = "cbor"))]
{ {
res.header("content-type", "application/json"); res_parts
.headers
.insert("content-type", HeaderValue::from_static("application/json"));
body = serde_json::to_vec(&rpc_res).unwrap_or_else(|_| JSON_INTERNAL_ERROR.clone()); body = serde_json::to_vec(&rpc_res).unwrap_or_else(|_| JSON_INTERNAL_ERROR.clone());
} }
res = res.header("content-length", body.len()); res_parts.headers.insert(
res = res.status(match &rpc_res.result { "content-length",
HeaderValue::from_str(&format!("{}", body.len()))?,
);
res_parts.status = match &rpc_res.result {
Ok(_) => StatusCode::OK, Ok(_) => StatusCode::OK,
Err(e) => status_code_fn(e.code), Err(e) => status_code_fn(e.code),
}); };
res.body(Body::from(body)) Ok(Response::from_parts(res_parts, body.into()))
} }
// &mut Request<Body> -> Result<Result<Future<&mut RpcRequest<...> -> Future<Result<Result<&mut Response<Body> -> Future<Result<(), HttpError>>, Response<Body>>, HttpError>>>, Response<Body>>, HttpError> // &mut Request<Body> -> Result<Result<Future<&mut RpcRequest<...> -> Future<Result<Result<&mut Response<Body> -> Future<Result<(), HttpError>>, Response<Body>>, HttpError>>>, Response<Body>>, HttpError>
pub type DynMiddleware<Params, Metadata> = Box< pub type DynMiddleware<Metadata> = Box<
dyn for<'a> FnOnce( dyn for<'a> FnOnce(
&'a mut Request<Body>, &'a mut Request<Body>,
Metadata, Metadata,
) -> BoxFuture< ) -> BoxFuture<
'a, 'a,
Result<Result<DynMiddlewareStage2<Params>, Response<Body>>, HttpError>, Result<Result<DynMiddlewareStage2, Response<Body>>, HttpError>,
> + Send > + Send
+ Sync, + Sync,
>; >;
pub fn noop<Params: for<'de> Deserialize<'de> + 'static, M: Metadata>() -> DynMiddleware<Params, M> pub fn noop<M: Metadata>() -> DynMiddleware<M> {
{
Box::new(|_, _| async { Ok(Ok(noop2())) }.boxed()) Box::new(|_, _| async { Ok(Ok(noop2())) }.boxed())
} }
pub type DynMiddlewareStage2<Params> = Box< pub type DynMiddlewareStage2 = Box<
dyn for<'a> FnOnce( dyn for<'a> FnOnce(
&'a mut RpcRequest<GenericRpcMethod<String, Params>>, &'a mut RequestParts,
&'a mut RpcRequest<GenericRpcMethod<String, Map<String, Value>>>,
) -> BoxFuture< ) -> BoxFuture<
'a, 'a,
Result<Result<DynMiddlewareStage3, Response<Body>>, HttpError>, Result<Result<DynMiddlewareStage3, Response<Body>>, HttpError>,
> + Send > + Send
+ Sync, + Sync,
>; >;
pub fn noop2<Params: for<'de> Deserialize<'de> + 'static>() -> DynMiddlewareStage2<Params> { pub fn noop2() -> DynMiddlewareStage2 {
Box::new(|_| async { Ok(Ok(noop3())) }.boxed()) Box::new(|_, _| async { Ok(Ok(noop3())) }.boxed())
} }
pub type DynMiddlewareStage3 = Box< pub type DynMiddlewareStage3 = Box<
dyn for<'a> FnOnce( dyn for<'a> FnOnce(
&'a mut ResponseParts,
&'a mut Result<Value, RpcError>, &'a mut Result<Value, RpcError>,
) -> BoxFuture< ) -> BoxFuture<
'a, 'a,
@@ -137,7 +147,7 @@ pub type DynMiddlewareStage3 = Box<
+ Sync, + Sync,
>; >;
pub fn noop3() -> DynMiddlewareStage3 { pub fn noop3() -> DynMiddlewareStage3 {
Box::new(|_| async { Ok(Ok(noop4())) }.boxed()) Box::new(|_, _| async { Ok(Ok(noop4())) }.boxed())
} }
pub type DynMiddlewareStage4 = Box< pub type DynMiddlewareStage4 = Box<
dyn for<'a> FnOnce(&'a mut Response<Body>) -> BoxFuture<'a, Result<(), HttpError>> dyn for<'a> FnOnce(&'a mut Response<Body>) -> BoxFuture<'a, Result<(), HttpError>>
@@ -153,13 +163,15 @@ pub fn constrain_middleware<
'b, 'b,
'c, 'c,
'd, 'd,
Params: for<'de> Deserialize<'de> + 'static,
M: Metadata, M: Metadata,
ReqFn: Fn(&'a mut Request<Body>, M) -> ReqFut, ReqFn: Fn(&'a mut Request<Body>, M) -> ReqFut,
ReqFut: Future<Output = Result<Result<RpcReqFn, Response<Body>>, HttpError>> + 'a, ReqFut: Future<Output = Result<Result<RpcReqFn, Response<Body>>, HttpError>> + 'a,
RpcReqFn: FnOnce(&'b mut RpcRequest<GenericRpcMethod<String, Params>>) -> RpcReqFut, RpcReqFn: FnOnce(
&'b mut RequestParts,
&'b mut RpcRequest<GenericRpcMethod<String, Map<String, Value>>>,
) -> RpcReqFut,
RpcReqFut: Future<Output = Result<Result<RpcResFn, Response<Body>>, HttpError>> + 'b, RpcReqFut: Future<Output = Result<Result<RpcResFn, Response<Body>>, HttpError>> + 'b,
RpcResFn: FnOnce(&'c mut Result<Value, RpcError>) -> RpcResFut, RpcResFn: FnOnce(&'c mut ResponseParts, &'c mut Result<Value, RpcError>) -> RpcResFut,
RpcResFut: Future<Output = Result<Result<ResFn, Response<Body>>, HttpError>> + 'c, RpcResFut: Future<Output = Result<Result<ResFn, Response<Body>>, HttpError>> + 'c,
ResFn: FnOnce(&'d mut Response<Body>) -> ResFut, ResFn: FnOnce(&'d mut Response<Body>) -> ResFut,
ResFut: Future<Output = Result<(), HttpError>> + 'd, ResFut: Future<Output = Result<(), HttpError>> + 'd,

View File

@@ -91,18 +91,18 @@ fn dothething2<U: Serialize + for<'a> Deserialize<'a> + FromStr<Err = E>, E: Dis
)) ))
} }
async fn cors<Params: for<'de> Deserialize<'de> + 'static, M: Metadata + 'static>( async fn cors<M: Metadata + 'static>(
req: &mut Request<Body>, req: &mut Request<Body>,
_: M, _: M,
) -> Result<Result<DynMiddlewareStage2<Params>, Response<Body>>, HttpError> { ) -> Result<Result<DynMiddlewareStage2, Response<Body>>, HttpError> {
if req.method() == hyper::Method::OPTIONS { if req.method() == hyper::Method::OPTIONS {
Ok(Err(Response::builder() Ok(Err(Response::builder()
.header("Access-Control-Allow-Origin", "*") .header("Access-Control-Allow-Origin", "*")
.body(Body::empty())?)) .body(Body::empty())?))
} else { } else {
Ok(Ok(Box::new(|_| { Ok(Ok(Box::new(|_, _| {
async move { async move {
let res: DynMiddlewareStage3 = Box::new(|_| { let res: DynMiddlewareStage3 = Box::new(|_, _| {
async move { async move {
let res: DynMiddlewareStage4 = Box::new(|res| { let res: DynMiddlewareStage4 = Box::new(|res| {
async move { async move {