access request and response headers from rpc body

This commit is contained in:
Aiden McClelland
2021-07-28 16:44:15 -06:00
parent da0d08fa35
commit 1e9ded9a31
7 changed files with 153 additions and 53 deletions

View File

@@ -444,7 +444,9 @@ fn rpc_handler(
quote! { args.#field_name }
}
ParamType::Context(_) => quote! { ctx },
_ => unreachable!(),
ParamType::Request => quote! { request },
ParamType::Response => quote! { response },
ParamType::None => unreachable!(),
});
match opt {
Options::Leaf(opt) if matches!(opt.exec_ctx, ExecutionContext::CliOnly(_)) => quote! {
@@ -452,6 +454,8 @@ fn rpc_handler(
pub async fn rpc_handler#fn_generics(
_ctx: #ctx_ty,
_request: &::rpc_toolkit::command_helpers::prelude::RequestParts,
_response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts,
method: &str,
_args: Params#param_ty_generics,
) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> {
@@ -480,6 +484,8 @@ fn rpc_handler(
pub async fn rpc_handler#fn_generics(
ctx: #ctx_ty,
request: &::rpc_toolkit::command_helpers::prelude::RequestParts,
response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts,
method: &str,
args: Params#param_ty_generics,
) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> {
@@ -515,7 +521,7 @@ fn rpc_handler(
),
};
quote_spanned!{ subcommand.span() =>
[#subcommand::NAME, rest] => #subcommand::#rpc_handler(ctx, rest, ::rpc_toolkit::command_helpers::prelude::from_value(args.rest)?).await
[#subcommand::NAME, rest] => #subcommand::#rpc_handler(ctx, request, response, rest, ::rpc_toolkit::command_helpers::prelude::from_value(args.rest)?).await
}
});
let subcmd_impl = quote! {
@@ -550,6 +556,8 @@ fn rpc_handler(
pub async fn rpc_handler#fn_generics(
ctx: #ctx_ty,
request: &::rpc_toolkit::command_helpers::prelude::RequestParts,
response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts,
method: &str,
args: Params#param_ty_generics,
) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> {
@@ -569,6 +577,8 @@ fn rpc_handler(
pub async fn rpc_handler#fn_generics(
ctx: #ctx_ty,
request: &::rpc_toolkit::command_helpers::prelude::RequestParts,
response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts,
method: &str,
args: Params#param_ty_generics,
) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> {
@@ -618,7 +628,9 @@ fn cli_handler(
quote! { params.#field_name.clone() }
}
ParamType::Context(_) => quote! { ctx },
_ => unreachable!(),
ParamType::Request => quote! { request },
ParamType::Response => quote! { response },
ParamType::None => unreachable!(),
});
let mut param_generics_filter = GenericFilter::new(fn_generics);
for param in params {

View File

@@ -90,4 +90,6 @@ pub enum ParamType {
None,
Arg(ArgOptions),
Context(Type),
Request,
Response,
}

View File

@@ -666,6 +666,16 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result<Vec<ParamType>> {
attr.span(),
"`arg` and `context` are mutually exclusive",
));
} else if matches!(ty, ParamType::Request) {
return Err(Error::new(
attr.span(),
"`arg` and `request` are mutually exclusive",
));
} else if matches!(ty, ParamType::Response) {
return Err(Error::new(
attr.span(),
"`arg` and `response` are mutually exclusive",
));
}
} else if param.attrs[i].path.is_ident("context") {
let attr = param.attrs.remove(i);
@@ -681,6 +691,66 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result<Vec<ParamType>> {
attr.span(),
"`arg` and `context` are mutually exclusive",
));
} else if matches!(ty, ParamType::Request) {
return Err(Error::new(
attr.span(),
"`context` and `request` are mutually exclusive",
));
} else if matches!(ty, ParamType::Response) {
return Err(Error::new(
attr.span(),
"`context` and `response` are mutually exclusive",
));
}
} else if param.attrs[i].path.is_ident("request") {
let attr = param.attrs.remove(i);
if matches!(ty, ParamType::None) {
ty = ParamType::Request;
} else if matches!(ty, ParamType::Request) {
return Err(Error::new(
attr.span(),
"`request` attribute may only be specified once",
));
} else if matches!(ty, ParamType::Arg(_)) {
return Err(Error::new(
attr.span(),
"`arg` and `request` are mutually exclusive",
));
} else if matches!(ty, ParamType::Context(_)) {
return Err(Error::new(
attr.span(),
"`context` and `request` are mutually exclusive",
));
} else if matches!(ty, ParamType::Response) {
return Err(Error::new(
attr.span(),
"`request` and `response` are mutually exclusive",
));
}
} else if param.attrs[i].path.is_ident("response") {
let attr = param.attrs.remove(i);
if matches!(ty, ParamType::None) {
ty = ParamType::Response;
} else if matches!(ty, ParamType::Response) {
return Err(Error::new(
attr.span(),
"`response` attribute may only be specified once",
));
} else if matches!(ty, ParamType::Arg(_)) {
return Err(Error::new(
attr.span(),
"`arg` and `response` are mutually exclusive",
));
} else if matches!(ty, ParamType::Context(_)) {
return Err(Error::new(
attr.span(),
"`context` and `response` are mutually exclusive",
));
} else if matches!(ty, ParamType::Request) {
return Err(Error::new(
attr.span(),
"`request` and `response` are mutually exclusive",
));
}
} else {
i += 1;

View File

@@ -54,29 +54,30 @@ pub fn build(args: RpcServerArgs) -> TokenStream {
Err(res) => return Ok(res),
};
)*
let rpc_req = ::rpc_toolkit::rpc_server_helpers::make_request(&mut req).await;
let (mut req_parts, req_body) = req.into_parts();
let (mut res_parts, _) = ::rpc_toolkit::hyper::Response::new(()).into_parts();
let rpc_req = ::rpc_toolkit::rpc_server_helpers::make_request(&req_parts, req_body).await;
match rpc_req {
Ok(mut rpc_req) => {
#(
let #middleware_name_post = match #middleware_name_pre2(&mut rpc_req).await? {
let #middleware_name_post = match #middleware_name_pre2(&mut req_parts, &mut rpc_req).await? {
Ok(a) => a,
Err(res) => return Ok(res),
};
)*
let mut rpc_res = #command(
ctx,
::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method),
rpc_req.params,
)
.await;
let mut rpc_res = match ::rpc_toolkit::serde_json::from_value(::rpc_toolkit::serde_json::Value::Object(rpc_req.params)) {
Ok(params) => #command(ctx, &req_parts, &mut res_parts, ::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method), params).await,
Err(e) => Err(e.into())
};
#(
let #middleware_name = match #middleware_name_post_inv(&mut rpc_res).await? {
let #middleware_name = match #middleware_name_post_inv(&mut res_parts, &mut rpc_res).await? {
Ok(a) => a,
Err(res) => return Ok(res),
};
)*
let mut res = ::rpc_toolkit::rpc_server_helpers::to_response(
&req,
&req_parts.headers,
res_parts,
Ok((
rpc_req.id,
rpc_res,
@@ -89,7 +90,8 @@ pub fn build(args: RpcServerArgs) -> TokenStream {
Ok::<_, ::rpc_toolkit::hyper::http::Error>(res)
}
Err(e) => ::rpc_toolkit::rpc_server_helpers::to_response(
&req,
&req_parts.headers,
res_parts,
Err(e),
status_fn,
),

View File

@@ -16,6 +16,8 @@ pub mod prelude {
pub use std::marker::PhantomData;
pub use clap::{App, AppSettings, Arg, ArgMatches};
pub use hyper::http::request::Parts as RequestParts;
pub use hyper::http::response::Parts as ResponseParts;
pub use serde::{Deserialize, Serialize};
pub use serde_json::{from_value, to_value, Value};
pub use tokio::runtime::Runtime;

View File

@@ -3,13 +3,15 @@ use std::future::Future;
use futures::future::BoxFuture;
use futures::FutureExt;
use hyper::body::Buf;
use hyper::header::HeaderValue;
use hyper::http::request::Parts as RequestParts;
use hyper::http::response::Parts as ResponseParts;
use hyper::http::Error as HttpError;
use hyper::server::conn::AddrIncoming;
use hyper::server::{Builder, Server};
use hyper::{Body, Request, Response, StatusCode};
use hyper::{Body, HeaderMap, Request, Response, StatusCode};
use lazy_static::lazy_static;
use serde::Deserialize;
use serde_json::Value;
use serde_json::{Map, Value};
use url::Host;
use yajrc::{AnyRpcMethod, GenericRpcMethod, Id, RpcError, RpcRequest, RpcResponse};
@@ -33,16 +35,15 @@ pub fn make_builder<Ctx: Context>(ctx: &Ctx) -> Builder<AddrIncoming> {
Server::bind(&addr)
}
pub async fn make_request<Params: for<'de> Deserialize<'de> + 'static>(
req: &mut Request<Body>,
) -> Result<RpcRequest<GenericRpcMethod<String, Params>>, RpcError> {
let body = hyper::body::aggregate(std::mem::replace(req.body_mut(), Body::empty()))
.await?
.reader();
let rpc_req: RpcRequest<GenericRpcMethod<String, Params>>;
pub async fn make_request(
req_parts: &RequestParts,
req_body: Body,
) -> Result<RpcRequest<GenericRpcMethod<String, Map<String, Value>>>, RpcError> {
let body = hyper::body::aggregate(req_body).await?.reader();
let rpc_req: RpcRequest<GenericRpcMethod<String, Map<String, Value>>>;
#[cfg(feature = "cbor")]
if req
.headers()
if req_parts
.headers
.get("content-type")
.and_then(|h| h.to_str().ok())
== Some("application/cbor")
@@ -60,7 +61,8 @@ pub async fn make_request<Params: for<'de> Deserialize<'de> + 'static>(
}
pub fn to_response<F: Fn(i32) -> StatusCode>(
req: &Request<Body>,
req_headers: &HeaderMap<HeaderValue>,
mut res_parts: ResponseParts,
res: Result<(Option<Id>, Result<Value, RpcError>), RpcError>,
status_code_fn: F,
) -> Result<Response<Body>, HttpError> {
@@ -69,10 +71,8 @@ pub fn to_response<F: Fn(i32) -> StatusCode>(
Err(e) => e.into(),
};
let body;
let mut res = Response::builder();
#[cfg(feature = "cbor")]
if req
.headers()
if req_headers
.get("accept")
.and_then(|h| h.to_str().ok())
.iter()
@@ -81,54 +81,64 @@ pub fn to_response<F: Fn(i32) -> StatusCode>(
.any(|s| s == "application/cbor")
// prefer cbor if accepted
{
res = res.header("content-type", "application/cbor");
res_parts
.headers
.insert("content-type", HeaderValue::from_static("application/cbor"));
body = serde_cbor::to_vec(&rpc_res).unwrap_or_else(|_| CBOR_INTERNAL_ERROR.clone());
} else {
res = res.header("content-type", "application/json");
res_parts
.headers
.insert("content-type", HeaderValue::from_static("application/json"));
body = serde_json::to_vec(&rpc_res).unwrap_or_else(|_| JSON_INTERNAL_ERROR.clone());
}
#[cfg(not(feature = "cbor"))]
{
res.header("content-type", "application/json");
res_parts
.headers
.insert("content-type", HeaderValue::from_static("application/json"));
body = serde_json::to_vec(&rpc_res).unwrap_or_else(|_| JSON_INTERNAL_ERROR.clone());
}
res = res.header("content-length", body.len());
res = res.status(match &rpc_res.result {
res_parts.headers.insert(
"content-length",
HeaderValue::from_str(&format!("{}", body.len()))?,
);
res_parts.status = match &rpc_res.result {
Ok(_) => StatusCode::OK,
Err(e) => status_code_fn(e.code),
});
res.body(Body::from(body))
};
Ok(Response::from_parts(res_parts, body.into()))
}
// &mut Request<Body> -> Result<Result<Future<&mut RpcRequest<...> -> Future<Result<Result<&mut Response<Body> -> Future<Result<(), HttpError>>, Response<Body>>, HttpError>>>, Response<Body>>, HttpError>
pub type DynMiddleware<Params, Metadata> = Box<
pub type DynMiddleware<Metadata> = Box<
dyn for<'a> FnOnce(
&'a mut Request<Body>,
Metadata,
) -> BoxFuture<
'a,
Result<Result<DynMiddlewareStage2<Params>, Response<Body>>, HttpError>,
Result<Result<DynMiddlewareStage2, Response<Body>>, HttpError>,
> + Send
+ Sync,
>;
pub fn noop<Params: for<'de> Deserialize<'de> + 'static, M: Metadata>() -> DynMiddleware<Params, M>
{
pub fn noop<M: Metadata>() -> DynMiddleware<M> {
Box::new(|_, _| async { Ok(Ok(noop2())) }.boxed())
}
pub type DynMiddlewareStage2<Params> = Box<
pub type DynMiddlewareStage2 = Box<
dyn for<'a> FnOnce(
&'a mut RpcRequest<GenericRpcMethod<String, Params>>,
&'a mut RequestParts,
&'a mut RpcRequest<GenericRpcMethod<String, Map<String, Value>>>,
) -> BoxFuture<
'a,
Result<Result<DynMiddlewareStage3, Response<Body>>, HttpError>,
> + Send
+ Sync,
>;
pub fn noop2<Params: for<'de> Deserialize<'de> + 'static>() -> DynMiddlewareStage2<Params> {
Box::new(|_| async { Ok(Ok(noop3())) }.boxed())
pub fn noop2() -> DynMiddlewareStage2 {
Box::new(|_, _| async { Ok(Ok(noop3())) }.boxed())
}
pub type DynMiddlewareStage3 = Box<
dyn for<'a> FnOnce(
&'a mut ResponseParts,
&'a mut Result<Value, RpcError>,
) -> BoxFuture<
'a,
@@ -137,7 +147,7 @@ pub type DynMiddlewareStage3 = Box<
+ Sync,
>;
pub fn noop3() -> DynMiddlewareStage3 {
Box::new(|_| async { Ok(Ok(noop4())) }.boxed())
Box::new(|_, _| async { Ok(Ok(noop4())) }.boxed())
}
pub type DynMiddlewareStage4 = Box<
dyn for<'a> FnOnce(&'a mut Response<Body>) -> BoxFuture<'a, Result<(), HttpError>>
@@ -153,13 +163,15 @@ pub fn constrain_middleware<
'b,
'c,
'd,
Params: for<'de> Deserialize<'de> + 'static,
M: Metadata,
ReqFn: Fn(&'a mut Request<Body>, M) -> ReqFut,
ReqFut: Future<Output = Result<Result<RpcReqFn, Response<Body>>, HttpError>> + 'a,
RpcReqFn: FnOnce(&'b mut RpcRequest<GenericRpcMethod<String, Params>>) -> RpcReqFut,
RpcReqFn: FnOnce(
&'b mut RequestParts,
&'b mut RpcRequest<GenericRpcMethod<String, Map<String, Value>>>,
) -> RpcReqFut,
RpcReqFut: Future<Output = Result<Result<RpcResFn, Response<Body>>, HttpError>> + 'b,
RpcResFn: FnOnce(&'c mut Result<Value, RpcError>) -> RpcResFut,
RpcResFn: FnOnce(&'c mut ResponseParts, &'c mut Result<Value, RpcError>) -> RpcResFut,
RpcResFut: Future<Output = Result<Result<ResFn, Response<Body>>, HttpError>> + 'c,
ResFn: FnOnce(&'d mut Response<Body>) -> ResFut,
ResFut: Future<Output = Result<(), HttpError>> + 'd,

View File

@@ -91,18 +91,18 @@ fn dothething2<U: Serialize + for<'a> Deserialize<'a> + FromStr<Err = E>, E: Dis
))
}
async fn cors<Params: for<'de> Deserialize<'de> + 'static, M: Metadata + 'static>(
async fn cors<M: Metadata + 'static>(
req: &mut Request<Body>,
_: M,
) -> Result<Result<DynMiddlewareStage2<Params>, Response<Body>>, HttpError> {
) -> Result<Result<DynMiddlewareStage2, Response<Body>>, HttpError> {
if req.method() == hyper::Method::OPTIONS {
Ok(Err(Response::builder()
.header("Access-Control-Allow-Origin", "*")
.body(Body::empty())?))
} else {
Ok(Ok(Box::new(|_| {
Ok(Ok(Box::new(|_, _| {
async move {
let res: DynMiddlewareStage3 = Box::new(|_| {
let res: DynMiddlewareStage3 = Box::new(|_, _| {
async move {
let res: DynMiddlewareStage4 = Box::new(|res| {
async move {