diff --git a/rpc-toolkit/src/cli.rs b/rpc-toolkit/src/cli.rs index da2ce96..fdee58a 100644 --- a/rpc-toolkit/src/cli.rs +++ b/rpc-toolkit/src/cli.rs @@ -1,4 +1,5 @@ use std::any::TypeId; +use std::collections::VecDeque; use std::ffi::OsString; use std::marker::PhantomData; @@ -63,14 +64,14 @@ impl let (method, params) = root_handler.cli_parse(&matches, ctx_ty)?; let res = root_handler.handle_sync(HandleAnyArgs { context: ctx.clone().upcast(), - parent_method: Vec::new(), + parent_method: VecDeque::new(), method: method.clone(), params: params.clone(), })?; root_handler.cli_display( HandleAnyArgs { context: ctx.upcast(), - parent_method: Vec::new(), + parent_method: VecDeque::new(), method, params, }, diff --git a/rpc-toolkit/src/handler.rs b/rpc-toolkit/src/handler.rs index 6086342..676ccba 100644 --- a/rpc-toolkit/src/handler.rs +++ b/rpc-toolkit/src/handler.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use clap::{ArgMatches, Command, CommandFactory, FromArgMatches, Parser}; use futures::Future; +use imbl_value::imbl::OrdMap; use imbl_value::Value; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; @@ -18,7 +19,7 @@ use crate::{CallRemote, CallRemoteHandler}; pub(crate) struct HandleAnyArgs { pub(crate) context: AnyContext, - pub(crate) parent_method: Vec<&'static str>, + pub(crate) parent_method: VecDeque<&'static str>, pub(crate) method: VecDeque<&'static str>, pub(crate) params: Value, } @@ -53,6 +54,11 @@ impl HandleAnyArgs { pub(crate) trait HandleAny: std::fmt::Debug + Send + Sync { fn handle_sync(&self, handle_args: HandleAnyArgs) -> Result; async fn handle_async(&self, handle_args: HandleAnyArgs) -> Result; + fn metadata( + &self, + method: VecDeque<&'static str>, + ctx_ty: TypeId, + ) -> OrdMap<&'static str, Value>; fn method_from_dots(&self, method: &str, ctx_ty: TypeId) -> Option>; } #[async_trait::async_trait] @@ -63,6 +69,13 @@ impl HandleAny for Arc { async fn handle_async(&self, handle_args: HandleAnyArgs) -> Result { self.deref().handle_async(handle_args).await } + fn metadata( + &self, + method: VecDeque<&'static str>, + ctx_ty: TypeId, + ) -> OrdMap<&'static str, Value> { + self.deref().metadata(method, ctx_ty) + } fn method_from_dots(&self, method: &str, ctx_ty: TypeId) -> Option> { self.deref().method_from_dots(method, ctx_ty) } @@ -122,6 +135,16 @@ impl HandleAny for DynHandler { DynHandler::WithCli(h) => h.handle_async(handle_args).await, } } + fn metadata( + &self, + method: VecDeque<&'static str>, + ctx_ty: TypeId, + ) -> OrdMap<&'static str, Value> { + match self { + DynHandler::WithoutCli(h) => h.metadata(method, ctx_ty), + DynHandler::WithCli(h) => h.metadata(method, ctx_ty), + } + } fn method_from_dots(&self, method: &str, ctx_ty: TypeId) -> Option> { match self { DynHandler::WithoutCli(h) => h.method_from_dots(method, ctx_ty), @@ -133,7 +156,7 @@ impl HandleAny for DynHandler { #[derive(Debug, Clone)] pub struct HandleArgs { pub context: Context, - pub parent_method: Vec<&'static str>, + pub parent_method: VecDeque<&'static str>, pub method: VecDeque<&'static str>, pub params: H::Params, pub inherited_params: H::InheritedParams, @@ -179,6 +202,14 @@ pub trait Handler: .await .unwrap() } + #[allow(unused_variables)] + fn metadata( + &self, + method: VecDeque<&'static str>, + ctx_ty: TypeId, + ) -> OrdMap<&'static str, Value> { + OrdMap::new() + } fn contexts(&self) -> Option> { Context::type_ids_for(self) } @@ -235,6 +266,13 @@ where ) .map_err(internal_error) } + fn metadata( + &self, + method: VecDeque<&'static str>, + ctx_ty: TypeId, + ) -> OrdMap<&'static str, Value> { + self.handler.metadata(method, ctx_ty) + } fn method_from_dots(&self, method: &str, ctx_ty: TypeId) -> Option> { self.handler.method_from_dots(method, ctx_ty) } @@ -318,20 +356,27 @@ impl SubcommandMap { pub struct ParentHandler { _phantom: PhantomData<(Params, InheritedParams)>, pub(crate) subcommands: SubcommandMap, + metadata: OrdMap<&'static str, Value>, } impl ParentHandler { pub fn new() -> Self { Self { _phantom: PhantomData, subcommands: SubcommandMap(BTreeMap::new()), + metadata: OrdMap::new(), } } + pub fn with_metadata(mut self, key: &'static str, value: Value) -> Self { + self.metadata.insert(key, value); + self + } } impl Clone for ParentHandler { fn clone(&self) -> Self { Self { _phantom: PhantomData, subcommands: self.subcommands.clone(), + metadata: self.metadata.clone(), } } } @@ -431,6 +476,19 @@ where raw_params, }) } + fn metadata( + &self, + method: VecDeque<&'static str>, + ctx_ty: TypeId, + ) -> OrdMap<&'static str, Value> { + self.handler.metadata(method, ctx_ty) + } + fn contexts(&self) -> Option> { + self.handler.contexts() + } + fn method_from_dots(&self, method: &str, ctx_ty: TypeId) -> Option> { + self.handler.method_from_dots(method, ctx_ty) + } } impl CliBindings @@ -742,7 +800,7 @@ where ) -> Result { let cmd = method.pop_front(); if let Some(cmd) = cmd { - parent_method.push(cmd); + parent_method.push_back(cmd); } if let Some((_, sub_handler)) = &self.subcommands.get(context.inner_type_id(), cmd) { sub_handler.handle_sync(HandleAnyArgs { @@ -767,7 +825,7 @@ where ) -> Result { let cmd = method.pop_front(); if let Some(cmd) = cmd { - parent_method.push(cmd); + parent_method.push_back(cmd); } if let Some((_, sub_handler)) = self.subcommands.get(context.inner_type_id(), cmd) { sub_handler @@ -782,6 +840,18 @@ where Err(yajrc::METHOD_NOT_FOUND_ERROR) } } + fn metadata( + &self, + mut method: VecDeque<&'static str>, + ctx_ty: TypeId, + ) -> OrdMap<&'static str, Value> { + let mut metadata = self.metadata.clone(); + if let Some((_, handler)) = self.subcommands.get(ctx_ty, method.pop_front()) { + handler.metadata(method, ctx_ty).union(metadata) + } else { + metadata + } + } fn contexts(&self) -> Option> { let mut set = BTreeSet::new(); for ctx_ty in self.subcommands.0.values().flat_map(|c| c.keys()) { @@ -872,7 +942,7 @@ where ) -> Result<(), Self::Err> { let cmd = method.pop_front(); if let Some(cmd) = cmd { - parent_method.push(cmd); + parent_method.push_back(cmd); } if let Some((_, DynHandler::WithCli(sub_handler))) = self.subcommands.get(context.inner_type_id(), cmd) @@ -896,6 +966,13 @@ pub struct FromFn { _phantom: PhantomData<(T, E, Args)>, function: F, blocking: bool, + metadata: OrdMap<&'static str, Value>, +} +impl FromFn { + pub fn with_metadata(mut self, key: &'static str, value: Value) -> Self { + self.metadata.insert(key, value); + self + } } impl Clone for FromFn { fn clone(&self) -> Self { @@ -903,6 +980,7 @@ impl Clone for FromFn { _phantom: PhantomData, function: self.function.clone(), blocking: self.blocking, + metadata: self.metadata.clone(), } } } @@ -929,6 +1007,7 @@ pub fn from_fn(function: F) -> FromFn { function, _phantom: PhantomData, blocking: false, + metadata: OrdMap::new(), } } @@ -937,18 +1016,21 @@ pub fn from_fn_blocking(function: F) -> FromFn { function, _phantom: PhantomData, blocking: true, + metadata: OrdMap::new(), } } pub struct FromFnAsync { _phantom: PhantomData<(Fut, T, E, Args)>, function: F, + metadata: OrdMap<&'static str, Value>, } impl Clone for FromFnAsync { fn clone(&self) -> Self { Self { _phantom: PhantomData, function: self.function.clone(), + metadata: self.metadata.clone(), } } } @@ -972,6 +1054,7 @@ pub fn from_fn_async(function: F) -> FromFnAsync, + ctx_ty: TypeId, + ) -> OrdMap<&'static str, Value> { + self.metadata.clone() + } } impl HandlerTypes for FromFnAsync where @@ -1032,6 +1122,13 @@ where async fn handle_async(&self, _: HandleArgs) -> Result { (self.function)().await } + fn metadata( + &self, + mut method: VecDeque<&'static str>, + ctx_ty: TypeId, + ) -> OrdMap<&'static str, Value> { + self.metadata.clone() + } } impl HandlerTypes for FromFn @@ -1067,6 +1164,13 @@ where self.handle_async_with_sync(handle_args).await } } + fn metadata( + &self, + mut method: VecDeque<&'static str>, + ctx_ty: TypeId, + ) -> OrdMap<&'static str, Value> { + self.metadata.clone() + } } impl HandlerTypes for FromFnAsync where @@ -1096,6 +1200,13 @@ where ) -> Result { (self.function)(handle_args.context).await } + fn metadata( + &self, + mut method: VecDeque<&'static str>, + ctx_ty: TypeId, + ) -> OrdMap<&'static str, Value> { + self.metadata.clone() + } } impl HandlerTypes for FromFn @@ -1136,6 +1247,13 @@ where self.handle_async_with_sync(handle_args).await } } + fn metadata( + &self, + mut method: VecDeque<&'static str>, + ctx_ty: TypeId, + ) -> OrdMap<&'static str, Value> { + self.metadata.clone() + } } impl HandlerTypes for FromFnAsync @@ -1172,6 +1290,13 @@ where } = handle_args; (self.function)(context, params).await } + fn metadata( + &self, + mut method: VecDeque<&'static str>, + ctx_ty: TypeId, + ) -> OrdMap<&'static str, Value> { + self.metadata.clone() + } } impl HandlerTypes @@ -1219,6 +1344,13 @@ where self.handle_async_with_sync(handle_args).await } } + fn metadata( + &self, + mut method: VecDeque<&'static str>, + ctx_ty: TypeId, + ) -> OrdMap<&'static str, Value> { + self.metadata.clone() + } } impl HandlerTypes @@ -1261,6 +1393,13 @@ where } = handle_args; (self.function)(context, params, inherited_params).await } + fn metadata( + &self, + mut method: VecDeque<&'static str>, + ctx_ty: TypeId, + ) -> OrdMap<&'static str, Value> { + self.metadata.clone() + } } impl CliBindings for FromFn diff --git a/rpc-toolkit/src/server/http.rs b/rpc-toolkit/src/server/http.rs index ab40aeb..00f092e 100644 --- a/rpc-toolkit/src/server/http.rs +++ b/rpc-toolkit/src/server/http.rs @@ -1,60 +1,287 @@ +use std::any::TypeId; use std::task::Context; -use futures::future::BoxFuture; +use futures::future::{join_all, BoxFuture}; +use futures::FutureExt; +use http::header::{CONTENT_LENGTH, CONTENT_TYPE}; use http::request::Parts; +use http_body_util::BodyExt; use hyper::body::{Bytes, Incoming}; +use hyper::service::Service; use hyper::{Request, Response}; -use yajrc::{RpcRequest, RpcResponse}; +use imbl_value::Value; +use serde::de::DeserializeOwned; +use serde::Serialize; +use yajrc::{RpcError, RpcMethod}; + +use crate::server::{RpcRequest, RpcResponse, SingleOrBatchRpcRequest}; +use crate::util::{internal_error, parse_error}; +use crate::{handler, HandleAny, Server}; + +const FALLBACK_ERROR: &str = "{\"error\":{\"code\":-32603,\"message\":\"Internal error\",\"data\":\"Failed to serialize rpc response\"}}"; + +pub fn fallback_rpc_error_response() -> Response { + Response::builder() + .header(CONTENT_TYPE, "application/json") + .header(CONTENT_LENGTH, FALLBACK_ERROR.len()) + .body(Bytes::from_static(FALLBACK_ERROR.as_bytes())) + .unwrap() +} + +pub fn json_http_response(t: &T) -> Response { + let body = match serde_json::to_vec(t) { + Ok(a) => a, + Err(_) => return fallback_rpc_error_response(), + }; + Response::builder() + .header(CONTENT_TYPE, "application/json") + .header(CONTENT_LENGTH, body.len()) + .body(Bytes::from(body)) + .unwrap_or_else(|_| fallback_rpc_error_response()) +} type BoxBody = http_body_util::combinators::BoxBody; #[async_trait::async_trait] -pub trait Middleware { - type ProcessHttpRequestResult; +pub trait Middleware: Clone + Send + Sync + 'static { + type Metadata: DeserializeOwned + Send + 'static; + #[allow(unused_variables)] async fn process_http_request( - &self, - req: &mut Request, - ) -> Result>>; - type ProcessRpcRequestResult; + &mut self, + context: &Context, + request: &mut Request, + ) -> Result<(), Response> { + Ok(()) + } + #[allow(unused_variables)] async fn process_rpc_request( - &self, - prev: Self::ProcessHttpRequestResult, - // metadata: &Context::Metadata, - req: &mut RpcRequest, - ) -> Result; - type ProcessRpcResponseResult; - async fn process_rpc_response( - &self, - prev: Self::ProcessRpcRequestResult, - res: &mut RpcResponse, - ) -> Self::ProcessRpcResponseResult; - async fn process_http_response( - &self, - prev: Self::ProcessRpcResponseResult, - res: &mut Response, - ); + &mut self, + metadata: Self::Metadata, + request: &mut RpcRequest, + ) -> Result<(), RpcResponse> { + Ok(()) + } + #[allow(unused_variables)] + async fn process_rpc_response(&mut self, response: &mut RpcResponse) {} + #[allow(unused_variables)] + async fn process_http_response(&mut self, response: &mut Response) {} } -// pub struct DynMiddleware { -// process_http_request: Box< -// dyn for<'a> Fn( -// &'a mut Request, -// ) -> BoxFuture< -// 'a, -// Result, hyper::Result>>, -// > + Send -// + Sync, -// >, -// } -// type DynProcessRpcRequest<'m, Context: crate::Context> = Box< -// dyn for<'a> FnOnce( -// &'a Context::Metadata, -// &'a mut RpcRequest, -// ) -// -> BoxFuture<'a, Result, DynSkipHandler<'m>>> -// + Send -// + Sync -// + 'm, -// >; -// type DynProcessRpcResponse<'m> = -// Box FnOnce(&'a mut RpcResponse) -> BoxFuture<'a, DynProcessHttpResponse<'m>>>; +#[allow(private_bounds)] +trait _Middleware: Send + Sync { + fn dyn_clone(&self) -> DynMiddleware; + fn process_http_request<'a>( + &'a mut self, + context: &'a Context, + request: &'a mut Request, + ) -> BoxFuture<'a, Result<(), Response>>; + fn process_rpc_request<'a>( + &'a mut self, + metadata: Value, + request: &'a mut RpcRequest, + ) -> BoxFuture<'a, Result<(), RpcResponse>>; + fn process_rpc_response<'a>(&'a mut self, response: &'a mut RpcResponse) -> BoxFuture<'a, ()>; + fn process_http_response<'a>( + &'a mut self, + response: &'a mut Response, + ) -> BoxFuture<'a, ()>; +} +impl + Send + Sync> _Middleware for T { + fn dyn_clone(&self) -> DynMiddleware { + DynMiddleware(Box::new(::clone(&self))) + } + fn process_http_request<'a>( + &'a mut self, + context: &'a Context, + request: &'a mut Request, + ) -> BoxFuture<'a, Result<(), Response>> { + >::process_http_request(self, context, request) + } + fn process_rpc_request<'a>( + &'a mut self, + metadata: Value, + request: &'a mut RpcRequest, + ) -> BoxFuture<'a, Result<(), RpcResponse>> { + >::process_rpc_request( + self, + match imbl_value::from_value(metadata) { + Ok(a) => a, + Err(e) => return async { Err(internal_error(e).into()) }.boxed(), + }, + request, + ) + } + fn process_rpc_response<'a>(&'a mut self, response: &'a mut RpcResponse) -> BoxFuture<'a, ()> { + >::process_rpc_response(self, response) + } + fn process_http_response<'a>( + &'a mut self, + response: &'a mut Response, + ) -> BoxFuture<'a, ()> { + >::process_http_response(self, response) + } +} + +struct DynMiddleware(Box>); +impl Clone for DynMiddleware { + fn clone(&self) -> Self { + self.0.dyn_clone() + } +} + +pub struct HttpServer { + inner: Server, + middleware: Vec>, +} +impl Clone for HttpServer { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + middleware: self.middleware.clone(), + } + } +} +impl Server { + pub fn for_http(self) -> HttpServer { + HttpServer { + inner: self, + middleware: Vec::new(), + } + } + pub fn middleware>(self, middleware: T) -> HttpServer { + self.for_http().middleware(middleware) + } +} +impl HttpServer { + pub fn middleware>(mut self, middleware: T) -> Self { + self.middleware.push(DynMiddleware(Box::new(middleware))); + self + } + async fn process_http_request(&self, mut req: Request) -> Response { + let mut mid = self.middleware.clone(); + match async { + let ctx = (self.inner.make_ctx)().await?; + for middleware in mid.iter_mut().rev() { + if let Err(e) = middleware.0.process_http_request(&ctx, &mut req).await { + return Ok::<_, RpcError>(e); + } + } + let (_, body) = req.into_parts(); + match serde_json::from_slice::( + &*body.collect().await.map_err(internal_error)?.to_bytes(), + ) + .map_err(parse_error)? + { + SingleOrBatchRpcRequest::Single(rpc_req) => { + let mut res = + json_http_response(&self.process_rpc_request(&mut mid, rpc_req).await); + for middleware in mid.iter_mut() { + middleware.0.process_http_response(&mut res).await; + } + Ok(res) + } + SingleOrBatchRpcRequest::Batch(rpc_reqs) => { + let (mids, rpc_res): (Vec<_>, Vec<_>) = + join_all(rpc_reqs.into_iter().map(|rpc_req| async { + let mut mid = mid.clone(); + let res = self.process_rpc_request(&mut mid, rpc_req).await; + (mid, res) + })) + .await + .into_iter() + .unzip(); + let mut res = json_http_response(&rpc_res); + for mut mid in mids.into_iter().fold( + vec![Vec::with_capacity(rpc_res.len()); mid.len()], + |mut acc, x| { + for (idx, middleware) in x.into_iter().enumerate() { + acc[idx].push(middleware); + } + acc + }, + ) { + for middleware in mid.iter_mut() { + middleware.0.process_http_response(&mut res).await; + } + } + Ok(res) + } + } + } + .await + { + Ok(a) => a, + Err(e) => json_http_response(&RpcResponse { + id: None, + result: Err(e), + }), + } + } + async fn process_rpc_request( + &self, + mid: &mut Vec>, + mut req: RpcRequest, + ) -> RpcResponse { + let metadata = Value::Object( + self.inner + .root_handler + .metadata( + match self + .inner + .root_handler + .method_from_dots(req.method.as_str(), TypeId::of::()) + { + Some(a) => a, + None => { + return RpcResponse { + id: req.id, + result: Err(yajrc::METHOD_NOT_FOUND_ERROR), + } + } + }, + TypeId::of::(), + ) + .into_iter() + .map(|(key, value)| (key.into(), value)) + .collect(), + ); + let mut res = async { + for middleware in mid.iter_mut().rev() { + if let Err(res) = middleware + .0 + .process_rpc_request(metadata.clone(), &mut req) + .await + { + return res; + } + } + self.inner.handle_single_request(req).await + } + .await; + for middleware in mid.iter_mut() { + middleware.0.process_rpc_response(&mut res).await; + } + res + } + pub fn handle(&self, req: Request) -> BoxFuture<'static, Response> { + let server = self.clone(); + async move { + server + .process_http_request(req.map(|b| BoxBody::new(b))) + .await + } + .boxed() + } +} + +impl Service> for HttpServer { + type Response = Response; + type Error = hyper::Error; + type Future = futures::future::Map< + BoxFuture<'static, Self::Response>, + fn(Self::Response) -> Result, + >; + fn call(&self, req: Request) -> Self::Future { + self.handle(req).map(Ok) + } +} diff --git a/rpc-toolkit/src/server/mod.rs b/rpc-toolkit/src/server/mod.rs index 5920717..a56f6f9 100644 --- a/rpc-toolkit/src/server/mod.rs +++ b/rpc-toolkit/src/server/mod.rs @@ -1,4 +1,5 @@ use std::any::TypeId; +use std::collections::VecDeque; use std::sync::Arc; use futures::future::{join_all, BoxFuture}; @@ -24,6 +25,14 @@ pub struct Server { make_ctx: Arc BoxFuture<'static, Result> + Send + Sync>, root_handler: Arc>, } +impl Clone for Server { + fn clone(&self) -> Self { + Self { + make_ctx: self.make_ctx.clone(), + root_handler: self.root_handler.clone(), + } + } +} impl Server { pub fn new< MakeCtx: Fn() -> Fut + Send + Sync + 'static, @@ -54,7 +63,7 @@ impl Server { root_handler .handle_async(HandleAnyArgs { context: make_ctx().await?.upcast(), - parent_method: Vec::new(), + parent_method: VecDeque::new(), method: method.ok_or_else(|| yajrc::METHOD_NOT_FOUND_ERROR)?, params, })