From df86610153288b1623a2935bf9d0fae9e39ae5d2 Mon Sep 17 00:00:00 2001 From: Aiden McClelland Date: Fri, 22 Dec 2023 13:19:07 -0700 Subject: [PATCH] fix cli binding inference --- rpc-toolkit/src/handler.rs | 272 ++++++++++++++++------------------- rpc-toolkit/tests/handler.rs | 2 +- 2 files changed, 128 insertions(+), 146 deletions(-) diff --git a/rpc-toolkit/src/handler.rs b/rpc-toolkit/src/handler.rs index 7187348..e93846a 100644 --- a/rpc-toolkit/src/handler.rs +++ b/rpc-toolkit/src/handler.rs @@ -99,123 +99,6 @@ pub trait PrintCliResult: Handler { ) -> Result<(), Self::Err>; } -#[derive(Debug)] -struct WithCliBindings { - _ctx: PhantomData, - handler: H, -} -impl Clone for WithCliBindings { - fn clone(&self) -> Self { - Self { - _ctx: PhantomData, - handler: self.handler.clone(), - } - } -} - -#[async_trait::async_trait] -impl Handler for WithCliBindings -where - Context: IntoContext, - H: Handler, -{ - type Params = H::Params; - type InheritedParams = H::InheritedParams; - type Ok = H::Ok; - type Err = H::Err; - fn handle_sync( - &self, - HandleArgs { - context, - parent_method, - method, - params, - inherited_params, - raw_params, - }: HandleArgs, - ) -> Result { - self.handler.handle_sync(HandleArgs { - context, - parent_method, - method, - params, - inherited_params, - raw_params, - }) - } - async fn handle_async( - &self, - HandleArgs { - context, - parent_method, - method, - params, - inherited_params, - raw_params, - }: HandleArgs, - ) -> Result { - self.handler - .handle_async(HandleArgs { - context, - parent_method, - method, - params, - inherited_params, - raw_params, - }) - .await - } -} - -impl CliBindings for WithCliBindings -where - Context: IntoContext, - H: Handler, - H::Params: FromArgMatches + CommandFactory + Serialize, - H: PrintCliResult, -{ - fn cli_command(&self, _: TypeId) -> Command { - H::Params::command() - } - fn cli_parse( - &self, - matches: &ArgMatches, - _: TypeId, - ) -> Result<(VecDeque<&'static str>, Value), clap::Error> { - H::Params::from_arg_matches(matches).and_then(|a| { - Ok(( - VecDeque::new(), - imbl_value::to_value(&a) - .map_err(|e| clap::Error::raw(clap::error::ErrorKind::ValueValidation, e))?, - )) - }) - } - fn cli_display( - &self, - HandleArgs { - context, - parent_method, - method, - params, - inherited_params, - raw_params, - }: HandleArgs, - result: Self::Ok, - ) -> Result<(), Self::Err> { - self.handler.print( - HandleArgs { - context, - parent_method, - method, - params, - inherited_params, - raw_params, - }, - result, - ) - } -} - pub(crate) trait HandleAnyWithCli: HandleAny + CliBindingsAny {} impl HandleAnyWithCli for T {} @@ -514,16 +397,26 @@ where } } -impl PrintCliResult +impl CliBindings for InheritanceHandler where Context: IntoContext, Params: Send + Sync + 'static, InheritedParams: Send + Sync + 'static, - H: Handler + PrintCliResult, + H: CliBindings, F: Fn(Params, InheritedParams) -> H::InheritedParams + Send + Sync + Clone + 'static, { - fn print( + fn cli_command(&self, ctx_ty: TypeId) -> Command { + self.handler.cli_command(ctx_ty) + } + fn cli_parse( + &self, + matches: &ArgMatches, + ctx_ty: TypeId, + ) -> Result<(VecDeque<&'static str>, Value), clap::Error> { + self.handler.cli_parse(matches, ctx_ty) + } + fn cli_display( &self, HandleArgs { context, @@ -535,7 +428,7 @@ where }: HandleArgs, result: Self::Ok, ) -> Result<(), Self::Err> { - self.handler.print( + self.handler.cli_display( HandleArgs { context, parent_method, @@ -553,8 +446,8 @@ impl ParentHandler(mut self, name: &'static str, handler: H) -> Self where Context: IntoContext, - H: Handler + PrintCliResult + 'static, - H::Params: FromArgMatches + CommandFactory + Serialize + DeserializeOwned, + H: CliBindings + 'static, + H::Params: DeserializeOwned, H::Ok: Serialize + DeserializeOwned, RpcError: From, { @@ -563,10 +456,7 @@ impl ParentHandler Self where Context: IntoContext, - H: Handler + PrintCliResult + 'static, - H::Params: FromArgMatches + CommandFactory + Serialize + DeserializeOwned, + H: CliBindings + 'static, + H::Params: DeserializeOwned, H::Ok: Serialize + DeserializeOwned, RpcError: From, F: Fn(Params, InheritedParams) -> H::InheritedParams + Send + Sync + Clone + 'static, @@ -614,13 +504,10 @@ where name.into(), DynHandler::WithCli(Arc::new(AnyHandler { _ctx: PhantomData, - handler: WithCliBindings { - _ctx: PhantomData, - handler: InheritanceHandler:: { - _phantom: PhantomData, - handler, - inherit, - }, + handler: InheritanceHandler:: { + _phantom: PhantomData, + handler, + inherit, }, })), ); @@ -657,8 +544,8 @@ where pub fn root_handler(mut self, handler: H, inherit: F) -> Self where Context: IntoContext, - H: Handler + PrintCliResult + 'static, - H::Params: FromArgMatches + CommandFactory + Serialize + DeserializeOwned, + H: CliBindings + 'static, + H::Params: DeserializeOwned, H::Ok: Serialize + DeserializeOwned, RpcError: From, F: Fn(Params, InheritedParams) -> H::InheritedParams + Send + Sync + Clone + 'static, @@ -668,13 +555,10 @@ where None, DynHandler::WithCli(Arc::new(AnyHandler { _ctx: PhantomData, - handler: WithCliBindings { - _ctx: PhantomData, - handler: InheritanceHandler:: { - _phantom: PhantomData, - handler, - inherit, - }, + handler: InheritanceHandler:: { + _phantom: PhantomData, + handler, + inherit, }, })), ); @@ -1160,3 +1044,101 @@ where (self.function)(context, params, inherited_params).await } } + +impl CliBindings for FromFn +where + Context: IntoContext, + Self: Handler, + Self::Params: FromArgMatches + CommandFactory + Serialize, + Self: PrintCliResult, +{ + fn cli_command(&self, _: TypeId) -> Command { + Self::Params::command() + } + fn cli_parse( + &self, + matches: &ArgMatches, + _: TypeId, + ) -> Result<(VecDeque<&'static str>, Value), clap::Error> { + Self::Params::from_arg_matches(matches).and_then(|a| { + Ok(( + VecDeque::new(), + imbl_value::to_value(&a) + .map_err(|e| clap::Error::raw(clap::error::ErrorKind::ValueValidation, e))?, + )) + }) + } + fn cli_display( + &self, + HandleArgs { + context, + parent_method, + method, + params, + inherited_params, + raw_params, + }: HandleArgs, + result: Self::Ok, + ) -> Result<(), Self::Err> { + self.print( + HandleArgs { + context, + parent_method, + method, + params, + inherited_params, + raw_params, + }, + result, + ) + } +} + +impl CliBindings for FromFnAsync +where + Context: IntoContext, + Self: Handler, + Self::Params: FromArgMatches + CommandFactory + Serialize, + Self: PrintCliResult, +{ + fn cli_command(&self, _: TypeId) -> Command { + Self::Params::command() + } + fn cli_parse( + &self, + matches: &ArgMatches, + _: TypeId, + ) -> Result<(VecDeque<&'static str>, Value), clap::Error> { + Self::Params::from_arg_matches(matches).and_then(|a| { + Ok(( + VecDeque::new(), + imbl_value::to_value(&a) + .map_err(|e| clap::Error::raw(clap::error::ErrorKind::ValueValidation, e))?, + )) + }) + } + fn cli_display( + &self, + HandleArgs { + context, + parent_method, + method, + params, + inherited_params, + raw_params, + }: HandleArgs, + result: Self::Ok, + ) -> Result<(), Self::Err> { + self.print( + HandleArgs { + context, + parent_method, + method, + params, + inherited_params, + raw_params, + }, + result, + ) + } +} diff --git a/rpc-toolkit/tests/handler.rs b/rpc-toolkit/tests/handler.rs index b45937c..6621f36 100644 --- a/rpc-toolkit/tests/handler.rs +++ b/rpc-toolkit/tests/handler.rs @@ -3,7 +3,7 @@ use std::path::PathBuf; use std::sync::Arc; use clap::Parser; -use rpc_toolkit::{from_fn, from_fn_async, AnyContext, CliApp, Context, ParentHandler}; +use rpc_toolkit::{from_fn, from_fn_async, AnyContext, CliApp, Context, NoParams, ParentHandler}; use serde::{Deserialize, Serialize}; use tokio::runtime::{Handle, Runtime}; use tokio::sync::OnceCell;