diff --git a/rpc-toolkit/src/cli.rs b/rpc-toolkit/src/cli.rs index 11f9a9c..9e0c993 100644 --- a/rpc-toolkit/src/cli.rs +++ b/rpc-toolkit/src/cli.rs @@ -15,7 +15,7 @@ use yajrc::{Id, RpcError}; use crate::util::{internal_error, parse_error, PhantomData}; use crate::{ - AnyHandler, CliBindingsAny, DynHandler, HandleAny, HandleAnyArgs, Handler, HandlerArgs, + AnyHandler, CliBindingsAny, DynHandler, Empty, HandleAny, HandleAnyArgs, Handler, HandlerArgs, HandlerArgsFor, HandlerTypes, IntoContext, Name, ParentHandler, PrintCliResult, }; @@ -84,11 +84,12 @@ impl } } -pub trait CallRemote: crate::Context { +pub trait CallRemote: crate::Context { fn call_remote( &self, method: &str, params: Value, + extra: Extra, ) -> impl Future> + Send; } @@ -171,7 +172,7 @@ pub async fn call_remote_socket( .result } -pub struct CallRemoteHandler { +struct CallRemoteHandler { _phantom: PhantomData, handler: RemoteHandler, } @@ -232,7 +233,11 @@ where .collect::>(); match handle_args .context - .call_remote(&full_method.join("."), handle_args.raw_params.clone()) + .call_remote( + &full_method.join("."), + handle_args.raw_params.clone(), + Empty {}, + ) .await { Ok(a) => imbl_value::from_value(a) diff --git a/rpc-toolkit/src/handler/adapters.rs b/rpc-toolkit/src/handler/adapters.rs index 7d9ad75..c5cdf8d 100644 --- a/rpc-toolkit/src/handler/adapters.rs +++ b/rpc-toolkit/src/handler/adapters.rs @@ -10,11 +10,11 @@ use serde::de::DeserializeOwned; use serde::Serialize; use yajrc::RpcError; -use crate::util::{internal_error, Flat, PhantomData}; +use crate::util::{internal_error, invalid_params, without, Flat, PhantomData}; use crate::{ iter_from_ctx_and_handler, AnyContext, AnyHandler, CallRemote, CliBindings, DynHandler, - EitherContext, Handler, HandlerArgs, HandlerArgsFor, HandlerTypes, IntoContext, IntoHandlers, - OrEmpty, PrintCliResult, + EitherContext, Empty, Handler, HandlerArgs, HandlerArgsFor, HandlerTypes, IntoContext, + IntoHandlers, OrEmpty, PrintCliResult, }; pub trait HandlerExt: Handler + Sized { @@ -40,7 +40,7 @@ pub trait HandlerExt: Handler + Sized { ) -> InheritanceHandler where F: Fn(Params, InheritedParams) -> Self::InheritedParams; - fn with_call_remote(self) -> RemoteCaller; + fn with_call_remote(self) -> RemoteCaller; } impl HandlerExt for T { @@ -90,7 +90,7 @@ impl HandlerExt for T { inherit: f, } } - fn with_call_remote(self) -> RemoteCaller { + fn with_call_remote(self) -> RemoteCaller { RemoteCaller { _phantom: PhantomData::new(), handler: self, @@ -452,11 +452,11 @@ where } } -pub struct RemoteCaller { - _phantom: PhantomData, +pub struct RemoteCaller { + _phantom: PhantomData<(Context, Extra)>, handler: H, } -impl Clone for RemoteCaller { +impl Clone for RemoteCaller { fn clone(&self) -> Self { Self { _phantom: PhantomData::new(), @@ -464,29 +464,31 @@ impl Clone for RemoteCaller { } } } -impl Debug for RemoteCaller { +impl Debug for RemoteCaller { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_tuple("RemoteCaller").field(&self.handler).finish() } } -impl HandlerTypes for RemoteCaller +impl HandlerTypes for RemoteCaller where H: HandlerTypes, + Extra: Send + Sync + 'static, { - type Params = H::Params; + type Params = Flat; type InheritedParams = H::InheritedParams; type Ok = H::Ok; type Err = H::Err; } -impl Handler for RemoteCaller +impl Handler for RemoteCaller where - Context: CallRemote, + Context: CallRemote, H: Handler, H::Params: Serialize, H::InheritedParams: Serialize, H::Ok: DeserializeOwned, H::Err: From, + Extra: Serialize + Send + Sync + 'static, { type Context = EitherContext; async fn handle_async( @@ -495,7 +497,7 @@ where context, parent_method, method, - params, + params: Flat(params, extra), inherited_params, raw_params, }: HandlerArgsFor, @@ -504,7 +506,11 @@ where EitherContext::C1(context) => { let full_method = parent_method.into_iter().chain(method).collect::>(); match context - .call_remote(&full_method.join("."), raw_params.clone()) + .call_remote( + &full_method.join("."), + without(raw_params, &extra).map_err(invalid_params)?, + extra, + ) .await { Ok(a) => imbl_value::from_value(a) @@ -553,7 +559,7 @@ where context, parent_method, method, - params, + params: Flat(params, _), inherited_params, raw_params, }: HandlerArgsFor, diff --git a/rpc-toolkit/src/util.rs b/rpc-toolkit/src/util.rs index b457e58..c7b71b2 100644 --- a/rpc-toolkit/src/util.rs +++ b/rpc-toolkit/src/util.rs @@ -10,10 +10,21 @@ use serde::{Deserialize, Serialize}; use yajrc::RpcError; pub fn extract(value: &Value) -> Result { - imbl_value::from_value(value.clone()).map_err(|e| RpcError { - data: Some(e.to_string().into()), - ..yajrc::INVALID_PARAMS_ERROR - }) + imbl_value::from_value(value.clone()).map_err(invalid_params) +} + +pub fn without(value: Value, remove: &T) -> Result { + let to_remove = imbl_value::to_value(remove)?; + let (Value::Object(mut value), Value::Object(to_remove)) = (value, to_remove) else { + return Err(imbl_value::Error { + kind: imbl_value::ErrorKind::Serialization, + source: serde_json::Error::custom("params must be object"), + }); + }; + for k in to_remove.keys() { + value.remove(k); + } + Ok(Value::Object(value)) } pub fn combine(v1: Value, v2: Value) -> Result { @@ -103,6 +114,49 @@ where .serialize(serializer) } } +impl clap::CommandFactory for Flat +where + A: clap::CommandFactory, + B: clap::Args, +{ + fn command() -> clap::Command { + B::augment_args(A::command()) + } + fn command_for_update() -> clap::Command { + B::augment_args_for_update(A::command_for_update()) + } +} +impl clap::FromArgMatches for Flat +where + A: clap::FromArgMatches, + B: clap::FromArgMatches, +{ + fn from_arg_matches(matches: &clap::ArgMatches) -> Result { + Ok(Self( + A::from_arg_matches(matches)?, + B::from_arg_matches(matches)?, + )) + } + fn from_arg_matches_mut(matches: &mut clap::ArgMatches) -> Result { + Ok(Self( + A::from_arg_matches_mut(matches)?, + B::from_arg_matches_mut(matches)?, + )) + } + fn update_from_arg_matches(&mut self, matches: &clap::ArgMatches) -> Result<(), clap::Error> { + self.0.update_from_arg_matches(matches)?; + self.1.update_from_arg_matches(matches)?; + Ok(()) + } + fn update_from_arg_matches_mut( + &mut self, + matches: &mut clap::ArgMatches, + ) -> Result<(), clap::Error> { + self.0.update_from_arg_matches_mut(matches)?; + self.1.update_from_arg_matches_mut(matches)?; + Ok(()) + } +} pub fn poll_select_all<'a, T>( futs: &mut Vec>, diff --git a/rpc-toolkit/tests/handler.rs b/rpc-toolkit/tests/handler.rs index ea599c5..5eb3929 100644 --- a/rpc-toolkit/tests/handler.rs +++ b/rpc-toolkit/tests/handler.rs @@ -62,7 +62,7 @@ impl Context for CliContext { } impl CallRemote for CliContext { - async fn call_remote(&self, method: &str, params: Value) -> Result { + async fn call_remote(&self, method: &str, params: Value, _: Empty) -> Result { call_remote_socket( tokio::net::UnixStream::connect(&self.0.host).await.unwrap(), method, @@ -129,7 +129,7 @@ fn make_api() -> ParentHandler { )) }, ) - .with_call_remote::(), + .with_call_remote::(), ) .subcommand( "hello",