mirror of
https://github.com/Start9Labs/rpc-toolkit.git
synced 2026-03-26 02:11:56 +00:00
access request and response headers from rpc body
This commit is contained in:
@@ -444,7 +444,9 @@ fn rpc_handler(
|
||||
quote! { args.#field_name }
|
||||
}
|
||||
ParamType::Context(_) => quote! { ctx },
|
||||
_ => unreachable!(),
|
||||
ParamType::Request => quote! { request },
|
||||
ParamType::Response => quote! { response },
|
||||
ParamType::None => unreachable!(),
|
||||
});
|
||||
match opt {
|
||||
Options::Leaf(opt) if matches!(opt.exec_ctx, ExecutionContext::CliOnly(_)) => quote! {
|
||||
@@ -452,6 +454,8 @@ fn rpc_handler(
|
||||
|
||||
pub async fn rpc_handler#fn_generics(
|
||||
_ctx: #ctx_ty,
|
||||
_request: &::rpc_toolkit::command_helpers::prelude::RequestParts,
|
||||
_response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts,
|
||||
method: &str,
|
||||
_args: Params#param_ty_generics,
|
||||
) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> {
|
||||
@@ -480,6 +484,8 @@ fn rpc_handler(
|
||||
|
||||
pub async fn rpc_handler#fn_generics(
|
||||
ctx: #ctx_ty,
|
||||
request: &::rpc_toolkit::command_helpers::prelude::RequestParts,
|
||||
response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts,
|
||||
method: &str,
|
||||
args: Params#param_ty_generics,
|
||||
) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> {
|
||||
@@ -515,7 +521,7 @@ fn rpc_handler(
|
||||
),
|
||||
};
|
||||
quote_spanned!{ subcommand.span() =>
|
||||
[#subcommand::NAME, rest] => #subcommand::#rpc_handler(ctx, rest, ::rpc_toolkit::command_helpers::prelude::from_value(args.rest)?).await
|
||||
[#subcommand::NAME, rest] => #subcommand::#rpc_handler(ctx, request, response, rest, ::rpc_toolkit::command_helpers::prelude::from_value(args.rest)?).await
|
||||
}
|
||||
});
|
||||
let subcmd_impl = quote! {
|
||||
@@ -550,6 +556,8 @@ fn rpc_handler(
|
||||
|
||||
pub async fn rpc_handler#fn_generics(
|
||||
ctx: #ctx_ty,
|
||||
request: &::rpc_toolkit::command_helpers::prelude::RequestParts,
|
||||
response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts,
|
||||
method: &str,
|
||||
args: Params#param_ty_generics,
|
||||
) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> {
|
||||
@@ -569,6 +577,8 @@ fn rpc_handler(
|
||||
|
||||
pub async fn rpc_handler#fn_generics(
|
||||
ctx: #ctx_ty,
|
||||
request: &::rpc_toolkit::command_helpers::prelude::RequestParts,
|
||||
response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts,
|
||||
method: &str,
|
||||
args: Params#param_ty_generics,
|
||||
) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> {
|
||||
@@ -618,7 +628,9 @@ fn cli_handler(
|
||||
quote! { params.#field_name.clone() }
|
||||
}
|
||||
ParamType::Context(_) => quote! { ctx },
|
||||
_ => unreachable!(),
|
||||
ParamType::Request => quote! { request },
|
||||
ParamType::Response => quote! { response },
|
||||
ParamType::None => unreachable!(),
|
||||
});
|
||||
let mut param_generics_filter = GenericFilter::new(fn_generics);
|
||||
for param in params {
|
||||
|
||||
@@ -90,4 +90,6 @@ pub enum ParamType {
|
||||
None,
|
||||
Arg(ArgOptions),
|
||||
Context(Type),
|
||||
Request,
|
||||
Response,
|
||||
}
|
||||
|
||||
@@ -666,6 +666,16 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result<Vec<ParamType>> {
|
||||
attr.span(),
|
||||
"`arg` and `context` are mutually exclusive",
|
||||
));
|
||||
} else if matches!(ty, ParamType::Request) {
|
||||
return Err(Error::new(
|
||||
attr.span(),
|
||||
"`arg` and `request` are mutually exclusive",
|
||||
));
|
||||
} else if matches!(ty, ParamType::Response) {
|
||||
return Err(Error::new(
|
||||
attr.span(),
|
||||
"`arg` and `response` are mutually exclusive",
|
||||
));
|
||||
}
|
||||
} else if param.attrs[i].path.is_ident("context") {
|
||||
let attr = param.attrs.remove(i);
|
||||
@@ -681,6 +691,66 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result<Vec<ParamType>> {
|
||||
attr.span(),
|
||||
"`arg` and `context` are mutually exclusive",
|
||||
));
|
||||
} else if matches!(ty, ParamType::Request) {
|
||||
return Err(Error::new(
|
||||
attr.span(),
|
||||
"`context` and `request` are mutually exclusive",
|
||||
));
|
||||
} else if matches!(ty, ParamType::Response) {
|
||||
return Err(Error::new(
|
||||
attr.span(),
|
||||
"`context` and `response` are mutually exclusive",
|
||||
));
|
||||
}
|
||||
} else if param.attrs[i].path.is_ident("request") {
|
||||
let attr = param.attrs.remove(i);
|
||||
if matches!(ty, ParamType::None) {
|
||||
ty = ParamType::Request;
|
||||
} else if matches!(ty, ParamType::Request) {
|
||||
return Err(Error::new(
|
||||
attr.span(),
|
||||
"`request` attribute may only be specified once",
|
||||
));
|
||||
} else if matches!(ty, ParamType::Arg(_)) {
|
||||
return Err(Error::new(
|
||||
attr.span(),
|
||||
"`arg` and `request` are mutually exclusive",
|
||||
));
|
||||
} else if matches!(ty, ParamType::Context(_)) {
|
||||
return Err(Error::new(
|
||||
attr.span(),
|
||||
"`context` and `request` are mutually exclusive",
|
||||
));
|
||||
} else if matches!(ty, ParamType::Response) {
|
||||
return Err(Error::new(
|
||||
attr.span(),
|
||||
"`request` and `response` are mutually exclusive",
|
||||
));
|
||||
}
|
||||
} else if param.attrs[i].path.is_ident("response") {
|
||||
let attr = param.attrs.remove(i);
|
||||
if matches!(ty, ParamType::None) {
|
||||
ty = ParamType::Response;
|
||||
} else if matches!(ty, ParamType::Response) {
|
||||
return Err(Error::new(
|
||||
attr.span(),
|
||||
"`response` attribute may only be specified once",
|
||||
));
|
||||
} else if matches!(ty, ParamType::Arg(_)) {
|
||||
return Err(Error::new(
|
||||
attr.span(),
|
||||
"`arg` and `response` are mutually exclusive",
|
||||
));
|
||||
} else if matches!(ty, ParamType::Context(_)) {
|
||||
return Err(Error::new(
|
||||
attr.span(),
|
||||
"`context` and `response` are mutually exclusive",
|
||||
));
|
||||
} else if matches!(ty, ParamType::Request) {
|
||||
return Err(Error::new(
|
||||
attr.span(),
|
||||
"`request` and `response` are mutually exclusive",
|
||||
));
|
||||
}
|
||||
} else {
|
||||
i += 1;
|
||||
|
||||
@@ -54,29 +54,30 @@ pub fn build(args: RpcServerArgs) -> TokenStream {
|
||||
Err(res) => return Ok(res),
|
||||
};
|
||||
)*
|
||||
let rpc_req = ::rpc_toolkit::rpc_server_helpers::make_request(&mut req).await;
|
||||
let (mut req_parts, req_body) = req.into_parts();
|
||||
let (mut res_parts, _) = ::rpc_toolkit::hyper::Response::new(()).into_parts();
|
||||
let rpc_req = ::rpc_toolkit::rpc_server_helpers::make_request(&req_parts, req_body).await;
|
||||
match rpc_req {
|
||||
Ok(mut rpc_req) => {
|
||||
#(
|
||||
let #middleware_name_post = match #middleware_name_pre2(&mut rpc_req).await? {
|
||||
let #middleware_name_post = match #middleware_name_pre2(&mut req_parts, &mut rpc_req).await? {
|
||||
Ok(a) => a,
|
||||
Err(res) => return Ok(res),
|
||||
};
|
||||
)*
|
||||
let mut rpc_res = #command(
|
||||
ctx,
|
||||
::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method),
|
||||
rpc_req.params,
|
||||
)
|
||||
.await;
|
||||
let mut rpc_res = match ::rpc_toolkit::serde_json::from_value(::rpc_toolkit::serde_json::Value::Object(rpc_req.params)) {
|
||||
Ok(params) => #command(ctx, &req_parts, &mut res_parts, ::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method), params).await,
|
||||
Err(e) => Err(e.into())
|
||||
};
|
||||
#(
|
||||
let #middleware_name = match #middleware_name_post_inv(&mut rpc_res).await? {
|
||||
let #middleware_name = match #middleware_name_post_inv(&mut res_parts, &mut rpc_res).await? {
|
||||
Ok(a) => a,
|
||||
Err(res) => return Ok(res),
|
||||
};
|
||||
)*
|
||||
let mut res = ::rpc_toolkit::rpc_server_helpers::to_response(
|
||||
&req,
|
||||
&req_parts.headers,
|
||||
res_parts,
|
||||
Ok((
|
||||
rpc_req.id,
|
||||
rpc_res,
|
||||
@@ -89,7 +90,8 @@ pub fn build(args: RpcServerArgs) -> TokenStream {
|
||||
Ok::<_, ::rpc_toolkit::hyper::http::Error>(res)
|
||||
}
|
||||
Err(e) => ::rpc_toolkit::rpc_server_helpers::to_response(
|
||||
&req,
|
||||
&req_parts.headers,
|
||||
res_parts,
|
||||
Err(e),
|
||||
status_fn,
|
||||
),
|
||||
|
||||
@@ -16,6 +16,8 @@ pub mod prelude {
|
||||
pub use std::marker::PhantomData;
|
||||
|
||||
pub use clap::{App, AppSettings, Arg, ArgMatches};
|
||||
pub use hyper::http::request::Parts as RequestParts;
|
||||
pub use hyper::http::response::Parts as ResponseParts;
|
||||
pub use serde::{Deserialize, Serialize};
|
||||
pub use serde_json::{from_value, to_value, Value};
|
||||
pub use tokio::runtime::Runtime;
|
||||
|
||||
@@ -3,13 +3,15 @@ use std::future::Future;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::FutureExt;
|
||||
use hyper::body::Buf;
|
||||
use hyper::header::HeaderValue;
|
||||
use hyper::http::request::Parts as RequestParts;
|
||||
use hyper::http::response::Parts as ResponseParts;
|
||||
use hyper::http::Error as HttpError;
|
||||
use hyper::server::conn::AddrIncoming;
|
||||
use hyper::server::{Builder, Server};
|
||||
use hyper::{Body, Request, Response, StatusCode};
|
||||
use hyper::{Body, HeaderMap, Request, Response, StatusCode};
|
||||
use lazy_static::lazy_static;
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::{Map, Value};
|
||||
use url::Host;
|
||||
use yajrc::{AnyRpcMethod, GenericRpcMethod, Id, RpcError, RpcRequest, RpcResponse};
|
||||
|
||||
@@ -33,16 +35,15 @@ pub fn make_builder<Ctx: Context>(ctx: &Ctx) -> Builder<AddrIncoming> {
|
||||
Server::bind(&addr)
|
||||
}
|
||||
|
||||
pub async fn make_request<Params: for<'de> Deserialize<'de> + 'static>(
|
||||
req: &mut Request<Body>,
|
||||
) -> Result<RpcRequest<GenericRpcMethod<String, Params>>, RpcError> {
|
||||
let body = hyper::body::aggregate(std::mem::replace(req.body_mut(), Body::empty()))
|
||||
.await?
|
||||
.reader();
|
||||
let rpc_req: RpcRequest<GenericRpcMethod<String, Params>>;
|
||||
pub async fn make_request(
|
||||
req_parts: &RequestParts,
|
||||
req_body: Body,
|
||||
) -> Result<RpcRequest<GenericRpcMethod<String, Map<String, Value>>>, RpcError> {
|
||||
let body = hyper::body::aggregate(req_body).await?.reader();
|
||||
let rpc_req: RpcRequest<GenericRpcMethod<String, Map<String, Value>>>;
|
||||
#[cfg(feature = "cbor")]
|
||||
if req
|
||||
.headers()
|
||||
if req_parts
|
||||
.headers
|
||||
.get("content-type")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
== Some("application/cbor")
|
||||
@@ -60,7 +61,8 @@ pub async fn make_request<Params: for<'de> Deserialize<'de> + 'static>(
|
||||
}
|
||||
|
||||
pub fn to_response<F: Fn(i32) -> StatusCode>(
|
||||
req: &Request<Body>,
|
||||
req_headers: &HeaderMap<HeaderValue>,
|
||||
mut res_parts: ResponseParts,
|
||||
res: Result<(Option<Id>, Result<Value, RpcError>), RpcError>,
|
||||
status_code_fn: F,
|
||||
) -> Result<Response<Body>, HttpError> {
|
||||
@@ -69,10 +71,8 @@ pub fn to_response<F: Fn(i32) -> StatusCode>(
|
||||
Err(e) => e.into(),
|
||||
};
|
||||
let body;
|
||||
let mut res = Response::builder();
|
||||
#[cfg(feature = "cbor")]
|
||||
if req
|
||||
.headers()
|
||||
if req_headers
|
||||
.get("accept")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.iter()
|
||||
@@ -81,54 +81,64 @@ pub fn to_response<F: Fn(i32) -> StatusCode>(
|
||||
.any(|s| s == "application/cbor")
|
||||
// prefer cbor if accepted
|
||||
{
|
||||
res = res.header("content-type", "application/cbor");
|
||||
res_parts
|
||||
.headers
|
||||
.insert("content-type", HeaderValue::from_static("application/cbor"));
|
||||
body = serde_cbor::to_vec(&rpc_res).unwrap_or_else(|_| CBOR_INTERNAL_ERROR.clone());
|
||||
} else {
|
||||
res = res.header("content-type", "application/json");
|
||||
res_parts
|
||||
.headers
|
||||
.insert("content-type", HeaderValue::from_static("application/json"));
|
||||
body = serde_json::to_vec(&rpc_res).unwrap_or_else(|_| JSON_INTERNAL_ERROR.clone());
|
||||
}
|
||||
#[cfg(not(feature = "cbor"))]
|
||||
{
|
||||
res.header("content-type", "application/json");
|
||||
res_parts
|
||||
.headers
|
||||
.insert("content-type", HeaderValue::from_static("application/json"));
|
||||
body = serde_json::to_vec(&rpc_res).unwrap_or_else(|_| JSON_INTERNAL_ERROR.clone());
|
||||
}
|
||||
res = res.header("content-length", body.len());
|
||||
res = res.status(match &rpc_res.result {
|
||||
res_parts.headers.insert(
|
||||
"content-length",
|
||||
HeaderValue::from_str(&format!("{}", body.len()))?,
|
||||
);
|
||||
res_parts.status = match &rpc_res.result {
|
||||
Ok(_) => StatusCode::OK,
|
||||
Err(e) => status_code_fn(e.code),
|
||||
});
|
||||
res.body(Body::from(body))
|
||||
};
|
||||
Ok(Response::from_parts(res_parts, body.into()))
|
||||
}
|
||||
|
||||
// &mut Request<Body> -> Result<Result<Future<&mut RpcRequest<...> -> Future<Result<Result<&mut Response<Body> -> Future<Result<(), HttpError>>, Response<Body>>, HttpError>>>, Response<Body>>, HttpError>
|
||||
pub type DynMiddleware<Params, Metadata> = Box<
|
||||
pub type DynMiddleware<Metadata> = Box<
|
||||
dyn for<'a> FnOnce(
|
||||
&'a mut Request<Body>,
|
||||
Metadata,
|
||||
) -> BoxFuture<
|
||||
'a,
|
||||
Result<Result<DynMiddlewareStage2<Params>, Response<Body>>, HttpError>,
|
||||
Result<Result<DynMiddlewareStage2, Response<Body>>, HttpError>,
|
||||
> + Send
|
||||
+ Sync,
|
||||
>;
|
||||
pub fn noop<Params: for<'de> Deserialize<'de> + 'static, M: Metadata>() -> DynMiddleware<Params, M>
|
||||
{
|
||||
pub fn noop<M: Metadata>() -> DynMiddleware<M> {
|
||||
Box::new(|_, _| async { Ok(Ok(noop2())) }.boxed())
|
||||
}
|
||||
pub type DynMiddlewareStage2<Params> = Box<
|
||||
pub type DynMiddlewareStage2 = Box<
|
||||
dyn for<'a> FnOnce(
|
||||
&'a mut RpcRequest<GenericRpcMethod<String, Params>>,
|
||||
&'a mut RequestParts,
|
||||
&'a mut RpcRequest<GenericRpcMethod<String, Map<String, Value>>>,
|
||||
) -> BoxFuture<
|
||||
'a,
|
||||
Result<Result<DynMiddlewareStage3, Response<Body>>, HttpError>,
|
||||
> + Send
|
||||
+ Sync,
|
||||
>;
|
||||
pub fn noop2<Params: for<'de> Deserialize<'de> + 'static>() -> DynMiddlewareStage2<Params> {
|
||||
Box::new(|_| async { Ok(Ok(noop3())) }.boxed())
|
||||
pub fn noop2() -> DynMiddlewareStage2 {
|
||||
Box::new(|_, _| async { Ok(Ok(noop3())) }.boxed())
|
||||
}
|
||||
pub type DynMiddlewareStage3 = Box<
|
||||
dyn for<'a> FnOnce(
|
||||
&'a mut ResponseParts,
|
||||
&'a mut Result<Value, RpcError>,
|
||||
) -> BoxFuture<
|
||||
'a,
|
||||
@@ -137,7 +147,7 @@ pub type DynMiddlewareStage3 = Box<
|
||||
+ Sync,
|
||||
>;
|
||||
pub fn noop3() -> DynMiddlewareStage3 {
|
||||
Box::new(|_| async { Ok(Ok(noop4())) }.boxed())
|
||||
Box::new(|_, _| async { Ok(Ok(noop4())) }.boxed())
|
||||
}
|
||||
pub type DynMiddlewareStage4 = Box<
|
||||
dyn for<'a> FnOnce(&'a mut Response<Body>) -> BoxFuture<'a, Result<(), HttpError>>
|
||||
@@ -153,13 +163,15 @@ pub fn constrain_middleware<
|
||||
'b,
|
||||
'c,
|
||||
'd,
|
||||
Params: for<'de> Deserialize<'de> + 'static,
|
||||
M: Metadata,
|
||||
ReqFn: Fn(&'a mut Request<Body>, M) -> ReqFut,
|
||||
ReqFut: Future<Output = Result<Result<RpcReqFn, Response<Body>>, HttpError>> + 'a,
|
||||
RpcReqFn: FnOnce(&'b mut RpcRequest<GenericRpcMethod<String, Params>>) -> RpcReqFut,
|
||||
RpcReqFn: FnOnce(
|
||||
&'b mut RequestParts,
|
||||
&'b mut RpcRequest<GenericRpcMethod<String, Map<String, Value>>>,
|
||||
) -> RpcReqFut,
|
||||
RpcReqFut: Future<Output = Result<Result<RpcResFn, Response<Body>>, HttpError>> + 'b,
|
||||
RpcResFn: FnOnce(&'c mut Result<Value, RpcError>) -> RpcResFut,
|
||||
RpcResFn: FnOnce(&'c mut ResponseParts, &'c mut Result<Value, RpcError>) -> RpcResFut,
|
||||
RpcResFut: Future<Output = Result<Result<ResFn, Response<Body>>, HttpError>> + 'c,
|
||||
ResFn: FnOnce(&'d mut Response<Body>) -> ResFut,
|
||||
ResFut: Future<Output = Result<(), HttpError>> + 'd,
|
||||
|
||||
@@ -91,18 +91,18 @@ fn dothething2<U: Serialize + for<'a> Deserialize<'a> + FromStr<Err = E>, E: Dis
|
||||
))
|
||||
}
|
||||
|
||||
async fn cors<Params: for<'de> Deserialize<'de> + 'static, M: Metadata + 'static>(
|
||||
async fn cors<M: Metadata + 'static>(
|
||||
req: &mut Request<Body>,
|
||||
_: M,
|
||||
) -> Result<Result<DynMiddlewareStage2<Params>, Response<Body>>, HttpError> {
|
||||
) -> Result<Result<DynMiddlewareStage2, Response<Body>>, HttpError> {
|
||||
if req.method() == hyper::Method::OPTIONS {
|
||||
Ok(Err(Response::builder()
|
||||
.header("Access-Control-Allow-Origin", "*")
|
||||
.body(Body::empty())?))
|
||||
} else {
|
||||
Ok(Ok(Box::new(|_| {
|
||||
Ok(Ok(Box::new(|_, _| {
|
||||
async move {
|
||||
let res: DynMiddlewareStage3 = Box::new(|_| {
|
||||
let res: DynMiddlewareStage3 = Box::new(|_, _| {
|
||||
async move {
|
||||
let res: DynMiddlewareStage4 = Box::new(|res| {
|
||||
async move {
|
||||
|
||||
Reference in New Issue
Block a user