diff --git a/rpc-toolkit-macro-internals/src/rpc_server/build.rs b/rpc-toolkit-macro-internals/src/rpc_server/build.rs index c590361..7d9af77 100644 --- a/rpc-toolkit-macro-internals/src/rpc_server/build.rs +++ b/rpc-toolkit-macro-internals/src/rpc_server/build.rs @@ -19,6 +19,14 @@ pub fn build(args: RpcServerArgs) -> TokenStream { let status_fn = args.status_fn.unwrap_or_else(|| { syn::parse2(quote! { |_| ::rpc_toolkit::hyper::StatusCode::OK }).unwrap() }); + let middleware_name_clone = (0..) + .map(|i| Ident::new(&format!("middleware_clone_{}", i), Span::call_site())) + .take(args.middleware.len()); + let middleware_name_clone2 = middleware_name_clone.clone(); + let middleware_name_clone3 = middleware_name_clone.clone(); + let middleware_name_clone4 = middleware_name_clone.clone(); + let middleware_name_clone5 = middleware_name_clone.clone(); + let middleware_name_clone6 = middleware_name_clone.clone(); let middleware_name_pre = (0..) .map(|i| Ident::new(&format!("middleware_pre_{}", i), Span::call_site())) .take(args.middleware.len()); @@ -41,15 +49,24 @@ pub fn build(args: RpcServerArgs) -> TokenStream { let ctx = #ctx; let status_fn = #status_fn; let builder = ::rpc_toolkit::rpc_server_helpers::make_builder(&ctx); + #( + let #middleware_name_clone = ::std::sync::Arc::new(#middleware); + )* let make_svc = ::rpc_toolkit::hyper::service::make_service_fn(move |_| { let ctx = ctx.clone(); + #( + let #middleware_name_clone3 = #middleware_name_clone2.clone(); + )* async move { Ok::<_, ::rpc_toolkit::hyper::Error>(::rpc_toolkit::hyper::service::service_fn(move |mut req| { let ctx = ctx.clone(); let metadata = #command_module::Metadata::default(); + #( + let #middleware_name_clone5 = #middleware_name_clone4.clone(); + )* async move { #( - let #middleware_name_pre = match ::rpc_toolkit::rpc_server_helpers::constrain_middleware(#middleware)(&mut req, metadata).await? { + let #middleware_name_pre = match ::rpc_toolkit::rpc_server_helpers::constrain_middleware(&*#middleware_name_clone6)(&mut req, metadata).await? { Ok(a) => a, Err(res) => return Ok(res), }; diff --git a/rpc-toolkit/src/rpc_server_helpers.rs b/rpc-toolkit/src/rpc_server_helpers.rs index 38f3051..57ddaac 100644 --- a/rpc-toolkit/src/rpc_server_helpers.rs +++ b/rpc-toolkit/src/rpc_server_helpers.rs @@ -1,4 +1,5 @@ use std::future::Future; +use std::sync::Arc; use futures::future::BoxFuture; use futures::FutureExt; @@ -163,7 +164,7 @@ pub fn constrain_middleware< 'c, 'd, M: Metadata, - ReqFn: Fn(&'a mut Request, M) -> ReqFut, + ReqFn: Fn(&'a mut Request, M) -> ReqFut + Clone, ReqFut: Future>, HttpError>> + 'a, RpcReqFn: FnOnce( &'b mut RequestParts, diff --git a/rpc-toolkit/tests/test.rs b/rpc-toolkit/tests/test.rs index da3f6de..8be9688 100644 --- a/rpc-toolkit/tests/test.rs +++ b/rpc-toolkit/tests/test.rs @@ -8,7 +8,8 @@ use rpc_toolkit::clap::Arg; use rpc_toolkit::hyper::http::Error as HttpError; use rpc_toolkit::hyper::{Body, Response}; use rpc_toolkit::rpc_server_helpers::{ - DynMiddlewareStage2, DynMiddlewareStage3, DynMiddlewareStage4, + constrain_middleware, DynMiddleware, DynMiddlewareStage2, DynMiddlewareStage3, + DynMiddlewareStage4, }; use rpc_toolkit::serde::{Deserialize, Serialize}; use rpc_toolkit::url::Host; @@ -215,3 +216,8 @@ fn cli_example() { }), data: () } ) } + +fn type_check() { + let middleware: DynMiddleware = todo!(); + constrain_middleware(&middleware); +}