diff --git a/rpc-toolkit-macro-internals/src/command/build.rs b/rpc-toolkit-macro-internals/src/command/build.rs index 6c6778b..0d0f91f 100644 --- a/rpc-toolkit-macro-internals/src/command/build.rs +++ b/rpc-toolkit-macro-internals/src/command/build.rs @@ -444,7 +444,9 @@ fn rpc_handler( quote! { args.#field_name } } ParamType::Context(_) => quote! { ctx }, - _ => unreachable!(), + ParamType::Request => quote! { request }, + ParamType::Response => quote! { response }, + ParamType::None => unreachable!(), }); match opt { Options::Leaf(opt) if matches!(opt.exec_ctx, ExecutionContext::CliOnly(_)) => quote! { @@ -452,6 +454,8 @@ fn rpc_handler( pub async fn rpc_handler#fn_generics( _ctx: #ctx_ty, + _request: &::rpc_toolkit::command_helpers::prelude::RequestParts, + _response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts, method: &str, _args: Params#param_ty_generics, ) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> { @@ -480,6 +484,8 @@ fn rpc_handler( pub async fn rpc_handler#fn_generics( ctx: #ctx_ty, + request: &::rpc_toolkit::command_helpers::prelude::RequestParts, + response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts, method: &str, args: Params#param_ty_generics, ) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> { @@ -515,7 +521,7 @@ fn rpc_handler( ), }; quote_spanned!{ subcommand.span() => - [#subcommand::NAME, rest] => #subcommand::#rpc_handler(ctx, rest, ::rpc_toolkit::command_helpers::prelude::from_value(args.rest)?).await + [#subcommand::NAME, rest] => #subcommand::#rpc_handler(ctx, request, response, rest, ::rpc_toolkit::command_helpers::prelude::from_value(args.rest)?).await } }); let subcmd_impl = quote! { @@ -550,6 +556,8 @@ fn rpc_handler( pub async fn rpc_handler#fn_generics( ctx: #ctx_ty, + request: &::rpc_toolkit::command_helpers::prelude::RequestParts, + response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts, method: &str, args: Params#param_ty_generics, ) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> { @@ -569,6 +577,8 @@ fn rpc_handler( pub async fn rpc_handler#fn_generics( ctx: #ctx_ty, + request: &::rpc_toolkit::command_helpers::prelude::RequestParts, + response: &mut ::rpc_toolkit::command_helpers::prelude::ResponseParts, method: &str, args: Params#param_ty_generics, ) -> Result<::rpc_toolkit::command_helpers::prelude::Value, ::rpc_toolkit::command_helpers::prelude::RpcError> { @@ -618,7 +628,9 @@ fn cli_handler( quote! { params.#field_name.clone() } } ParamType::Context(_) => quote! { ctx }, - _ => unreachable!(), + ParamType::Request => quote! { request }, + ParamType::Response => quote! { response }, + ParamType::None => unreachable!(), }); let mut param_generics_filter = GenericFilter::new(fn_generics); for param in params { diff --git a/rpc-toolkit-macro-internals/src/command/mod.rs b/rpc-toolkit-macro-internals/src/command/mod.rs index ba8d4f2..3e458ac 100644 --- a/rpc-toolkit-macro-internals/src/command/mod.rs +++ b/rpc-toolkit-macro-internals/src/command/mod.rs @@ -90,4 +90,6 @@ pub enum ParamType { None, Arg(ArgOptions), Context(Type), + Request, + Response, } diff --git a/rpc-toolkit-macro-internals/src/command/parse.rs b/rpc-toolkit-macro-internals/src/command/parse.rs index 422df68..3779101 100644 --- a/rpc-toolkit-macro-internals/src/command/parse.rs +++ b/rpc-toolkit-macro-internals/src/command/parse.rs @@ -666,6 +666,16 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result> { attr.span(), "`arg` and `context` are mutually exclusive", )); + } else if matches!(ty, ParamType::Request) { + return Err(Error::new( + attr.span(), + "`arg` and `request` are mutually exclusive", + )); + } else if matches!(ty, ParamType::Response) { + return Err(Error::new( + attr.span(), + "`arg` and `response` are mutually exclusive", + )); } } else if param.attrs[i].path.is_ident("context") { let attr = param.attrs.remove(i); @@ -681,6 +691,66 @@ pub fn parse_param_attrs(item: &mut ItemFn) -> Result> { attr.span(), "`arg` and `context` are mutually exclusive", )); + } else if matches!(ty, ParamType::Request) { + return Err(Error::new( + attr.span(), + "`context` and `request` are mutually exclusive", + )); + } else if matches!(ty, ParamType::Response) { + return Err(Error::new( + attr.span(), + "`context` and `response` are mutually exclusive", + )); + } + } else if param.attrs[i].path.is_ident("request") { + let attr = param.attrs.remove(i); + if matches!(ty, ParamType::None) { + ty = ParamType::Request; + } else if matches!(ty, ParamType::Request) { + return Err(Error::new( + attr.span(), + "`request` attribute may only be specified once", + )); + } else if matches!(ty, ParamType::Arg(_)) { + return Err(Error::new( + attr.span(), + "`arg` and `request` are mutually exclusive", + )); + } else if matches!(ty, ParamType::Context(_)) { + return Err(Error::new( + attr.span(), + "`context` and `request` are mutually exclusive", + )); + } else if matches!(ty, ParamType::Response) { + return Err(Error::new( + attr.span(), + "`request` and `response` are mutually exclusive", + )); + } + } else if param.attrs[i].path.is_ident("response") { + let attr = param.attrs.remove(i); + if matches!(ty, ParamType::None) { + ty = ParamType::Response; + } else if matches!(ty, ParamType::Response) { + return Err(Error::new( + attr.span(), + "`response` attribute may only be specified once", + )); + } else if matches!(ty, ParamType::Arg(_)) { + return Err(Error::new( + attr.span(), + "`arg` and `response` are mutually exclusive", + )); + } else if matches!(ty, ParamType::Context(_)) { + return Err(Error::new( + attr.span(), + "`context` and `response` are mutually exclusive", + )); + } else if matches!(ty, ParamType::Request) { + return Err(Error::new( + attr.span(), + "`request` and `response` are mutually exclusive", + )); } } else { i += 1; diff --git a/rpc-toolkit-macro-internals/src/rpc_server/build.rs b/rpc-toolkit-macro-internals/src/rpc_server/build.rs index 48c901c..c590361 100644 --- a/rpc-toolkit-macro-internals/src/rpc_server/build.rs +++ b/rpc-toolkit-macro-internals/src/rpc_server/build.rs @@ -54,29 +54,30 @@ pub fn build(args: RpcServerArgs) -> TokenStream { Err(res) => return Ok(res), }; )* - let rpc_req = ::rpc_toolkit::rpc_server_helpers::make_request(&mut req).await; + let (mut req_parts, req_body) = req.into_parts(); + let (mut res_parts, _) = ::rpc_toolkit::hyper::Response::new(()).into_parts(); + let rpc_req = ::rpc_toolkit::rpc_server_helpers::make_request(&req_parts, req_body).await; match rpc_req { Ok(mut rpc_req) => { #( - let #middleware_name_post = match #middleware_name_pre2(&mut rpc_req).await? { + let #middleware_name_post = match #middleware_name_pre2(&mut req_parts, &mut rpc_req).await? { Ok(a) => a, Err(res) => return Ok(res), }; )* - let mut rpc_res = #command( - ctx, - ::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method), - rpc_req.params, - ) - .await; + let mut rpc_res = match ::rpc_toolkit::serde_json::from_value(::rpc_toolkit::serde_json::Value::Object(rpc_req.params)) { + Ok(params) => #command(ctx, &req_parts, &mut res_parts, ::rpc_toolkit::yajrc::RpcMethod::as_str(&rpc_req.method), params).await, + Err(e) => Err(e.into()) + }; #( - let #middleware_name = match #middleware_name_post_inv(&mut rpc_res).await? { + let #middleware_name = match #middleware_name_post_inv(&mut res_parts, &mut rpc_res).await? { Ok(a) => a, Err(res) => return Ok(res), }; )* let mut res = ::rpc_toolkit::rpc_server_helpers::to_response( - &req, + &req_parts.headers, + res_parts, Ok(( rpc_req.id, rpc_res, @@ -89,7 +90,8 @@ pub fn build(args: RpcServerArgs) -> TokenStream { Ok::<_, ::rpc_toolkit::hyper::http::Error>(res) } Err(e) => ::rpc_toolkit::rpc_server_helpers::to_response( - &req, + &req_parts.headers, + res_parts, Err(e), status_fn, ), diff --git a/rpc-toolkit/src/command_helpers.rs b/rpc-toolkit/src/command_helpers.rs index 23e8eca..30c06e3 100644 --- a/rpc-toolkit/src/command_helpers.rs +++ b/rpc-toolkit/src/command_helpers.rs @@ -16,6 +16,8 @@ pub mod prelude { pub use std::marker::PhantomData; pub use clap::{App, AppSettings, Arg, ArgMatches}; + pub use hyper::http::request::Parts as RequestParts; + pub use hyper::http::response::Parts as ResponseParts; pub use serde::{Deserialize, Serialize}; pub use serde_json::{from_value, to_value, Value}; pub use tokio::runtime::Runtime; diff --git a/rpc-toolkit/src/rpc_server_helpers.rs b/rpc-toolkit/src/rpc_server_helpers.rs index 174b435..bf8b2b3 100644 --- a/rpc-toolkit/src/rpc_server_helpers.rs +++ b/rpc-toolkit/src/rpc_server_helpers.rs @@ -3,13 +3,15 @@ use std::future::Future; use futures::future::BoxFuture; use futures::FutureExt; use hyper::body::Buf; +use hyper::header::HeaderValue; +use hyper::http::request::Parts as RequestParts; +use hyper::http::response::Parts as ResponseParts; use hyper::http::Error as HttpError; use hyper::server::conn::AddrIncoming; use hyper::server::{Builder, Server}; -use hyper::{Body, Request, Response, StatusCode}; +use hyper::{Body, HeaderMap, Request, Response, StatusCode}; use lazy_static::lazy_static; -use serde::Deserialize; -use serde_json::Value; +use serde_json::{Map, Value}; use url::Host; use yajrc::{AnyRpcMethod, GenericRpcMethod, Id, RpcError, RpcRequest, RpcResponse}; @@ -33,16 +35,15 @@ pub fn make_builder(ctx: &Ctx) -> Builder { Server::bind(&addr) } -pub async fn make_request Deserialize<'de> + 'static>( - req: &mut Request, -) -> Result>, RpcError> { - let body = hyper::body::aggregate(std::mem::replace(req.body_mut(), Body::empty())) - .await? - .reader(); - let rpc_req: RpcRequest>; +pub async fn make_request( + req_parts: &RequestParts, + req_body: Body, +) -> Result>>, RpcError> { + let body = hyper::body::aggregate(req_body).await?.reader(); + let rpc_req: RpcRequest>>; #[cfg(feature = "cbor")] - if req - .headers() + if req_parts + .headers .get("content-type") .and_then(|h| h.to_str().ok()) == Some("application/cbor") @@ -60,7 +61,8 @@ pub async fn make_request Deserialize<'de> + 'static>( } pub fn to_response StatusCode>( - req: &Request, + req_headers: &HeaderMap, + mut res_parts: ResponseParts, res: Result<(Option, Result), RpcError>, status_code_fn: F, ) -> Result, HttpError> { @@ -69,10 +71,8 @@ pub fn to_response StatusCode>( Err(e) => e.into(), }; let body; - let mut res = Response::builder(); #[cfg(feature = "cbor")] - if req - .headers() + if req_headers .get("accept") .and_then(|h| h.to_str().ok()) .iter() @@ -81,54 +81,64 @@ pub fn to_response StatusCode>( .any(|s| s == "application/cbor") // prefer cbor if accepted { - res = res.header("content-type", "application/cbor"); + res_parts + .headers + .insert("content-type", HeaderValue::from_static("application/cbor")); body = serde_cbor::to_vec(&rpc_res).unwrap_or_else(|_| CBOR_INTERNAL_ERROR.clone()); } else { - res = res.header("content-type", "application/json"); + res_parts + .headers + .insert("content-type", HeaderValue::from_static("application/json")); body = serde_json::to_vec(&rpc_res).unwrap_or_else(|_| JSON_INTERNAL_ERROR.clone()); } #[cfg(not(feature = "cbor"))] { - res.header("content-type", "application/json"); + res_parts + .headers + .insert("content-type", HeaderValue::from_static("application/json")); body = serde_json::to_vec(&rpc_res).unwrap_or_else(|_| JSON_INTERNAL_ERROR.clone()); } - res = res.header("content-length", body.len()); - res = res.status(match &rpc_res.result { + res_parts.headers.insert( + "content-length", + HeaderValue::from_str(&format!("{}", body.len()))?, + ); + res_parts.status = match &rpc_res.result { Ok(_) => StatusCode::OK, Err(e) => status_code_fn(e.code), - }); - res.body(Body::from(body)) + }; + Ok(Response::from_parts(res_parts, body.into())) } // &mut Request -> Result -> Future -> Future>, Response>, HttpError>>>, Response>, HttpError> -pub type DynMiddleware = Box< +pub type DynMiddleware = Box< dyn for<'a> FnOnce( &'a mut Request, Metadata, ) -> BoxFuture< 'a, - Result, Response>, HttpError>, + Result>, HttpError>, > + Send + Sync, >; -pub fn noop Deserialize<'de> + 'static, M: Metadata>() -> DynMiddleware -{ +pub fn noop() -> DynMiddleware { Box::new(|_, _| async { Ok(Ok(noop2())) }.boxed()) } -pub type DynMiddlewareStage2 = Box< +pub type DynMiddlewareStage2 = Box< dyn for<'a> FnOnce( - &'a mut RpcRequest>, + &'a mut RequestParts, + &'a mut RpcRequest>>, ) -> BoxFuture< 'a, Result>, HttpError>, > + Send + Sync, >; -pub fn noop2 Deserialize<'de> + 'static>() -> DynMiddlewareStage2 { - Box::new(|_| async { Ok(Ok(noop3())) }.boxed()) +pub fn noop2() -> DynMiddlewareStage2 { + Box::new(|_, _| async { Ok(Ok(noop3())) }.boxed()) } pub type DynMiddlewareStage3 = Box< dyn for<'a> FnOnce( + &'a mut ResponseParts, &'a mut Result, ) -> BoxFuture< 'a, @@ -137,7 +147,7 @@ pub type DynMiddlewareStage3 = Box< + Sync, >; pub fn noop3() -> DynMiddlewareStage3 { - Box::new(|_| async { Ok(Ok(noop4())) }.boxed()) + Box::new(|_, _| async { Ok(Ok(noop4())) }.boxed()) } pub type DynMiddlewareStage4 = Box< dyn for<'a> FnOnce(&'a mut Response) -> BoxFuture<'a, Result<(), HttpError>> @@ -153,13 +163,15 @@ pub fn constrain_middleware< 'b, 'c, 'd, - Params: for<'de> Deserialize<'de> + 'static, M: Metadata, ReqFn: Fn(&'a mut Request, M) -> ReqFut, ReqFut: Future>, HttpError>> + 'a, - RpcReqFn: FnOnce(&'b mut RpcRequest>) -> RpcReqFut, + RpcReqFn: FnOnce( + &'b mut RequestParts, + &'b mut RpcRequest>>, + ) -> RpcReqFut, RpcReqFut: Future>, HttpError>> + 'b, - RpcResFn: FnOnce(&'c mut Result) -> RpcResFut, + RpcResFn: FnOnce(&'c mut ResponseParts, &'c mut Result) -> RpcResFut, RpcResFut: Future>, HttpError>> + 'c, ResFn: FnOnce(&'d mut Response) -> ResFut, ResFut: Future> + 'd, diff --git a/rpc-toolkit/tests/test.rs b/rpc-toolkit/tests/test.rs index 65018ec..da3f6de 100644 --- a/rpc-toolkit/tests/test.rs +++ b/rpc-toolkit/tests/test.rs @@ -91,18 +91,18 @@ fn dothething2 Deserialize<'a> + FromStr, E: Dis )) } -async fn cors Deserialize<'de> + 'static, M: Metadata + 'static>( +async fn cors( req: &mut Request, _: M, -) -> Result, Response>, HttpError> { +) -> Result>, HttpError> { if req.method() == hyper::Method::OPTIONS { Ok(Err(Response::builder() .header("Access-Control-Allow-Origin", "*") .body(Body::empty())?)) } else { - Ok(Ok(Box::new(|_| { + Ok(Ok(Box::new(|_, _| { async move { - let res: DynMiddlewareStage3 = Box::new(|_| { + let res: DynMiddlewareStage3 = Box::new(|_, _| { async move { let res: DynMiddlewareStage4 = Box::new(|res| { async move {