diff --git a/rpc-toolkit/src/cli.rs b/rpc-toolkit/src/cli.rs index 232002f..6f11f9f 100644 --- a/rpc-toolkit/src/cli.rs +++ b/rpc-toolkit/src/cli.rs @@ -1,232 +1,82 @@ -use std::{ffi::OsString, marker::PhantomData}; +use std::any::TypeId; +use std::ffi::OsString; +use std::marker::PhantomData; -use clap::{ArgMatches, CommandFactory, FromArgMatches}; -use futures::{future::BoxFuture, never::Never}; -use futures::{Future, FutureExt}; +use clap::{CommandFactory, FromArgMatches}; use imbl_value::Value; use reqwest::{Client, Method}; use serde::de::DeserializeOwned; use serde::Serialize; -use std::marker::PhantomData; use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; use url::Url; use yajrc::{Id, RpcError}; -use crate::{command::ParentCommand, CliBindings, EmptyHandler, HandleArgs, Handler, NoParams}; -// use crate::command::{AsyncCommand, DynCommand, LeafCommand, ParentInfo}; -use crate::util::{combine, internal_error, invalid_params, parse_error}; -// use crate::{CliBindings, SyncCommand}; +use crate::util::{internal_error, parse_error, Flat}; +use crate::{ + AnyHandler, CliBindingsAny, DynHandler, HandleAny, HandleAnyArgs, HandleArgs, Handler, + IntoContext, Name, ParentHandler, +}; type GenericRpcMethod<'a> = yajrc::GenericRpcMethod<&'a str, Value, Value>; type RpcRequest<'a> = yajrc::RpcRequest>; type RpcResponse<'a> = yajrc::RpcResponse>; -// impl DynCommand { -// fn cli_app(&self) -> Option { -// if let Some(cli) = &self.cli { -// Some( -// cli.cmd -// .clone() -// .name(self.name) -// .subcommands(self.subcommands.iter().filter_map(|c| c.cli_app())), -// ) -// } else { -// None -// } -// } -// fn cmd_from_cli_matches( -// &self, -// matches: &ArgMatches, -// parent: ParentInfo, -// ) -> Result<(Vec<&'static str>, Value, &DynCommand), RpcError> { -// let params = combine( -// parent.params, -// (self -// .cli -// .as_ref() -// .ok_or(yajrc::METHOD_NOT_FOUND_ERROR)? -// .parser)(matches)?, -// )?; -// if let Some((cmd, matches)) = matches.subcommand() { -// let mut method = parent.method; -// method.push(self.name); -// self.subcommands -// .iter() -// .find(|c| c.name == cmd) -// .ok_or(yajrc::METHOD_NOT_FOUND_ERROR)? -// .cmd_from_cli_matches(matches, ParentInfo { method, params }) -// } else { -// Ok((parent.method, params, self)) -// } -// } -// } - -struct RootCliHandler( - PhantomData<(Context, Config)>, -); -impl Handler - for RootCliHandler +pub struct CliApp { + _phantom: PhantomData<(Context, Config)>, + make_ctx: Box Result + Send + Sync>, + root_handler: ParentHandler, +} +impl + CliApp { - type Params = NoParams; - type InheritedParams = NoParams; - type Ok = Never; - type Err = RpcError; - fn handle_sync(&self, _: HandleArgs) -> Result { - Err(yajrc::METHOD_NOT_FOUND_ERROR) - } -} -impl CliBindings - for RootCliHandler -{ - fn cli_command(&self) -> clap::Command { - Config::command() - } - - fn cli_parse( - &self, - matches: &ArgMatches, - ) -> Result<(std::collections::VecDeque<&'static str>, Value), clap::Error> { - } - - fn cli_display( - &self, - handle_args: HandleArgs, - result: Self::Ok, - ) -> Result<(), Self::Err> { - todo!() - } -} - -struct CliApp( - ParentCommand>, -); -impl CliApp { - pub fn new(commands: Vec>) -> Self { - Self { - cli: CliBindings::from_parent::(), - commands, - } - } - fn cmd_from_cli_matches( - &self, - matches: &ArgMatches, - ) -> Result<(Vec<&'static str>, Value, &DynCommand), RpcError> { - if let Some((cmd, matches)) = matches.subcommand() { - Ok(self - .commands - .iter() - .find(|c| c.name == cmd) - .ok_or(yajrc::METHOD_NOT_FOUND_ERROR)? - .cmd_from_cli_matches( - matches, - ParentInfo { - method: Vec::new(), - params: Value::Object(Default::default()), - }, - )?) - } else { - Err(yajrc::METHOD_NOT_FOUND_ERROR) - } - } -} - -pub struct CliAppAsync { - app: CliApp, - make_ctx: Box BoxFuture<'static, Result> + Send>, -} -impl CliAppAsync { - pub fn new< - Cmd: FromArgMatches + CommandFactory + Serialize + DeserializeOwned + Send, - F: FnOnce(Cmd) -> Fut + Send + 'static, - Fut: Future> + Send, - >( - make_ctx: F, - commands: Vec>, + pub fn new Result + Send + Sync + 'static>( + make_ctx: MakeCtx, + root_handler: ParentHandler, ) -> Self { Self { - app: CliApp::new::(commands), - make_ctx: Box::new(|args| { - async { make_ctx(imbl_value::from_value(args).map_err(parse_error)?).await }.boxed() - }), + _phantom: PhantomData, + make_ctx: Box::new(make_ctx), + root_handler, } } -} -impl CliAppAsync { - pub async fn run(self, args: Vec) -> Result<(), RpcError> { - let cmd = self - .app - .cli - .cmd - .clone() - .subcommands(self.app.commands.iter().filter_map(|c| c.cli_app())); + pub fn run(self, args: impl IntoIterator) -> Result<(), RpcError> { + let ctx_ty = TypeId::of::(); + let mut cmd = Config::command(); + for (name, handlers) in &self.root_handler.subcommands.0 { + if let (Name(Some(name)), Some(DynHandler::WithCli(handler))) = ( + name, + if let Some(handler) = handlers.get(&Some(ctx_ty)) { + Some(handler) + } else if let Some(handler) = handlers.get(&None) { + Some(handler) + } else { + None + }, + ) { + cmd = cmd.subcommand(handler.cli_command(ctx_ty).name(name)); + } + } let matches = cmd.get_matches_from(args); - let make_ctx_args = (self.app.cli.parser)(&matches)?; - let ctx = (self.make_ctx)(make_ctx_args).await?; - let (parent_method, params, cmd) = self.app.cmd_from_cli_matches(&matches)?; - let display = &cmd - .cli - .as_ref() - .ok_or(yajrc::METHOD_NOT_FOUND_ERROR)? - .display; - let res = (cmd - .implementation - .as_ref() - .ok_or(yajrc::METHOD_NOT_FOUND_ERROR)? - .async_impl)(ctx.clone(), parent_method.clone(), params.clone()) - .await?; - if let Some(display) = display { - display(ctx, parent_method, params, res).map_err(parse_error) - } else { - Ok(()) - } - } -} - -pub struct CliAppSync { - app: CliApp, - make_ctx: Box Result + Send>, -} -impl CliAppSync { - pub fn new< - Cmd: FromArgMatches + CommandFactory + Serialize + DeserializeOwned + Send, - F: FnOnce(Cmd) -> Result + Send + 'static, - >( - make_ctx: F, - commands: Vec>, - ) -> Self { - Self { - app: CliApp::new::(commands), - make_ctx: Box::new(|args| make_ctx(imbl_value::from_value(args).map_err(parse_error)?)), - } - } -} -impl CliAppSync { - pub async fn run(self, args: Vec) -> Result<(), RpcError> { - let cmd = self - .app - .cli - .cmd - .clone() - .subcommands(self.app.commands.iter().filter_map(|c| c.cli_app())); - let matches = cmd.get_matches_from(args); - let make_ctx_args = (self.app.cli.parser)(&matches)?; - let ctx = (self.make_ctx)(make_ctx_args)?; - let (parent_method, params, cmd) = self.app.cmd_from_cli_matches(&matches)?; - let display = &cmd - .cli - .as_ref() - .ok_or(yajrc::METHOD_NOT_FOUND_ERROR)? - .display; - let res = (cmd - .implementation - .as_ref() - .ok_or(yajrc::METHOD_NOT_FOUND_ERROR)? - .sync_impl)(ctx.clone(), parent_method.clone(), params.clone())?; - if let Some(display) = display { - display(ctx, parent_method, params, res).map_err(parse_error) - } else { - Ok(()) - } + let config = Config::from_arg_matches(&matches)?; + let ctx = (self.make_ctx)(config)?; + let root_handler = AnyHandler::new(self.root_handler); + let (method, params) = root_handler.cli_parse(&matches, ctx_ty)?; + let res = root_handler.handle_sync(HandleAnyArgs { + context: ctx.clone().upcast(), + parent_method: Vec::new(), + method: method.clone(), + params: params.clone(), + })?; + root_handler.cli_display( + HandleAnyArgs { + context: ctx.upcast(), + parent_method: Vec::new(), + method, + params, + }, + res, + )?; + Ok(()) } } @@ -285,6 +135,15 @@ pub trait CliContextHttp: crate::Context { } } } +#[async_trait::async_trait] +impl CliContext for T +where + T: CliContextHttp, +{ + async fn call_remote(&self, method: &str, params: Value) -> Result { + ::call_remote(&self, method, params).await + } +} #[async_trait::async_trait] pub trait CliContextSocket: crate::Context { @@ -318,66 +177,55 @@ pub trait CliContextSocket: crate::Context { } } -pub trait RemoteCommand: LeafCommand {} +#[derive(Debug, Default)] +pub struct CallRemote(PhantomData<(RemoteContext, RemoteHandler)>); +impl CallRemote { + pub fn new() -> Self { + Self(PhantomData) + } +} +impl Clone for CallRemote { + fn clone(&self) -> Self { + Self(PhantomData) + } +} #[async_trait::async_trait] -impl AsyncCommand for T +impl Handler + for CallRemote where - T: RemoteCommand + Send + Serialize, - T::Parent: Serialize, - T::Ok: DeserializeOwned, - T::Err: From, - Context: CliContext + Send + 'static, + RemoteContext: IntoContext, + RemoteHandler: Handler, + RemoteHandler::Params: Serialize, + RemoteHandler::InheritedParams: Serialize, + RemoteHandler::Ok: DeserializeOwned, + RemoteHandler::Err: From, { - async fn implementation( - self, - ctx: Context, - parent: ParentInfo, + type Params = RemoteHandler::Params; + type InheritedParams = RemoteHandler::InheritedParams; + type Ok = RemoteHandler::Ok; + type Err = RemoteHandler::Err; + async fn handle_async( + &self, + handle_args: HandleArgs, ) -> Result { - let mut method = parent.method; - method.push(Self::NAME); - Ok(imbl_value::from_value( - ctx.call_remote( - &method.join("."), - combine( - imbl_value::to_value(&self).map_err(invalid_params)?, - imbl_value::to_value(&parent.params).map_err(invalid_params)?, - ) - .map_err(invalid_params)?, + let full_method = handle_args + .parent_method + .into_iter() + .chain(handle_args.method) + .collect::>(); + match handle_args + .context + .call_remote( + &full_method.join("."), + imbl_value::to_value(&Flat(handle_args.params, handle_args.inherited_params)) + .map_err(parse_error)?, ) - .await?, - ) - .map_err(parse_error)?) - } -} - -impl SyncCommand for T -where - T: RemoteCommand + Send + Serialize, - T::Parent: Serialize, - T::Ok: DeserializeOwned, - T::Err: From, - Context: CliContext + Send + 'static, -{ - const BLOCKING: bool = true; - fn implementation( - self, - ctx: Context, - parent: ParentInfo, - ) -> Result { - let mut method = parent.method; - method.push(Self::NAME); - Ok(imbl_value::from_value( - ctx.runtime().block_on( - ctx.call_remote( - &method.join("."), - combine( - imbl_value::to_value(&self).map_err(invalid_params)?, - imbl_value::to_value(&parent.params).map_err(invalid_params)?, - ) - .map_err(invalid_params)?, - ), - )?, - ) - .map_err(parse_error)?) + .await + { + Ok(a) => imbl_value::from_value(a) + .map_err(internal_error) + .map_err(Self::Err::from), + Err(e) => Err(Self::Err::from(e)), + } } } diff --git a/rpc-toolkit/src/context.rs b/rpc-toolkit/src/context.rs index 79720f8..f2c0e43 100644 --- a/rpc-toolkit/src/context.rs +++ b/rpc-toolkit/src/context.rs @@ -5,14 +5,15 @@ use tokio::runtime::Handle; use crate::Handler; -pub trait Context: Any + Send + 'static { +pub trait Context: Any + Send + Sync + 'static { fn runtime(&self) -> Handle { Handle::current() } } #[allow(private_bounds)] -pub trait IntoContext: sealed::Sealed + Any + Send + Sized + 'static { +pub trait IntoContext: sealed::Sealed + Any + Send + Sync + Sized + 'static { + fn runtime(&self) -> Handle; fn type_ids_for + ?Sized>(handler: &H) -> Option>; fn inner_type_id(&self) -> TypeId; fn upcast(self) -> AnyContext; @@ -20,7 +21,10 @@ pub trait IntoContext: sealed::Sealed + Any + Send + Sized + 'static { } impl IntoContext for C { - fn type_ids_for + ?Sized>(handler: &H) -> Option> { + fn runtime(&self) -> Handle { + ::runtime(&self) + } + fn type_ids_for + ?Sized>(_: &H) -> Option> { let mut set = BTreeSet::new(); set.insert(TypeId::of::()); Some(set) @@ -45,7 +49,13 @@ pub enum EitherContext { C2(C2), } impl IntoContext for EitherContext { - fn type_ids_for + ?Sized>(handler: &H) -> Option> { + fn runtime(&self) -> Handle { + match self { + Self::C1(a) => a.runtime(), + Self::C2(a) => a.runtime(), + } + } + fn type_ids_for + ?Sized>(_: &H) -> Option> { let mut set = BTreeSet::new(); set.insert(TypeId::of::()); set.insert(TypeId::of::()); @@ -88,6 +98,9 @@ impl AnyContext { } impl IntoContext for AnyContext { + fn runtime(&self) -> Handle { + self.0.runtime() + } fn type_ids_for + ?Sized>(_: &H) -> Option> { None } diff --git a/rpc-toolkit/src/handler.rs b/rpc-toolkit/src/handler.rs index ab1c498..b9e86f1 100644 --- a/rpc-toolkit/src/handler.rs +++ b/rpc-toolkit/src/handler.rs @@ -1,9 +1,11 @@ use std::any::TypeId; use std::collections::{BTreeMap, BTreeSet, VecDeque}; use std::marker::PhantomData; +use std::ops::Deref; use std::sync::Arc; use clap::{ArgMatches, Command, CommandFactory, FromArgMatches, Parser}; +use futures::Future; use imbl_value::Value; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; @@ -12,11 +14,11 @@ use yajrc::RpcError; use crate::context::{AnyContext, IntoContext}; use crate::util::{combine, internal_error, invalid_params, Flat}; -struct HandleAnyArgs { - context: AnyContext, - parent_method: Vec<&'static str>, - method: VecDeque<&'static str>, - params: Value, +pub(crate) struct HandleAnyArgs { + pub(crate) context: AnyContext, + pub(crate) parent_method: Vec<&'static str>, + pub(crate) method: VecDeque<&'static str>, + pub(crate) params: Value, } impl HandleAnyArgs { fn downcast(self) -> Result, imbl_value::Error> @@ -40,17 +42,31 @@ impl HandleAnyArgs { method, params: imbl_value::from_value(params.clone())?, inherited_params: imbl_value::from_value(params.clone())?, + raw_params: params, }) } } #[async_trait::async_trait] -trait HandleAny { +pub(crate) trait HandleAny: Send + Sync { fn handle_sync(&self, handle_args: HandleAnyArgs) -> Result; - // async fn handle_async(&self, handle_args: HandleAnyArgs) -> Result; + async fn handle_async(&self, handle_args: HandleAnyArgs) -> Result; + fn method_from_dots(&self, method: &str, ctx_ty: TypeId) -> Option>; +} +#[async_trait::async_trait] +impl HandleAny for Arc { + fn handle_sync(&self, handle_args: HandleAnyArgs) -> Result { + self.deref().handle_sync(handle_args) + } + async fn handle_async(&self, handle_args: HandleAnyArgs) -> Result { + self.deref().handle_async(handle_args).await + } + fn method_from_dots(&self, method: &str, ctx_ty: TypeId) -> Option> { + self.deref().method_from_dots(method, ctx_ty) + } } -trait CliBindingsAny { +pub(crate) trait CliBindingsAny { fn cli_command(&self, ctx_ty: TypeId) -> Command; fn cli_parse( &self, @@ -97,11 +113,21 @@ pub trait PrintCliResult: Handler { // } // } +#[derive(Debug)] struct WithCliBindings { _ctx: PhantomData, handler: H, } +impl Clone for WithCliBindings { + fn clone(&self) -> Self { + Self { + _ctx: PhantomData, + handler: self.handler.clone(), + } + } +} +#[async_trait::async_trait] impl Handler for WithCliBindings where Context: IntoContext, @@ -119,6 +145,7 @@ where method, params, inherited_params, + raw_params, }: HandleArgs, ) -> Result { self.handler.handle_sync(HandleArgs { @@ -127,8 +154,31 @@ where method, params, inherited_params, + raw_params, }) } + async fn handle_async( + &self, + HandleArgs { + context, + parent_method, + method, + params, + inherited_params, + raw_params, + }: HandleArgs, + ) -> Result { + self.handler + .handle_async(HandleArgs { + context, + parent_method, + method, + params, + inherited_params, + raw_params, + }) + .await + } } impl CliBindings for WithCliBindings @@ -162,6 +212,7 @@ where method, params, inherited_params, + raw_params, }: HandleArgs, result: Self::Ok, ) -> Result<(), Self::Err> { @@ -172,20 +223,22 @@ where method, params, inherited_params, + raw_params, }, result, ) } } -trait HandleAnyWithCli: HandleAny + CliBindingsAny {} +pub(crate) trait HandleAnyWithCli: HandleAny + CliBindingsAny {} impl HandleAnyWithCli for T {} #[derive(Clone)] -enum DynHandler { +pub(crate) enum DynHandler { WithoutCli(Arc), WithCli(Arc), } +#[async_trait::async_trait] impl HandleAny for DynHandler { fn handle_sync(&self, handle_args: HandleAnyArgs) -> Result { match self { @@ -193,32 +246,91 @@ impl HandleAny for DynHandler { DynHandler::WithCli(h) => h.handle_sync(handle_args), } } -} - -pub struct HandleArgs + ?Sized> { - context: Context, - parent_method: Vec<&'static str>, - method: VecDeque<&'static str>, - params: H::Params, - inherited_params: H::InheritedParams, -} - -pub trait Handler { - type Params; - type InheritedParams; - type Ok; - type Err; - fn handle_sync(&self, handle_args: HandleArgs) -> Result; - fn contexts(&self) -> Option> { - Context::type_ids_for(self) + async fn handle_async(&self, handle_args: HandleAnyArgs) -> Result { + match self { + DynHandler::WithoutCli(h) => h.handle_async(handle_args).await, + DynHandler::WithCli(h) => h.handle_async(handle_args).await, + } + } + fn method_from_dots(&self, method: &str, ctx_ty: TypeId) -> Option> { + match self { + DynHandler::WithoutCli(h) => h.method_from_dots(method, ctx_ty), + DynHandler::WithCli(h) => h.method_from_dots(method, ctx_ty), + } } } -struct AnyHandler { +#[derive(Debug, Clone)] +pub struct HandleArgs + ?Sized> { + pub context: Context, + pub parent_method: Vec<&'static str>, + pub method: VecDeque<&'static str>, + pub params: H::Params, + pub inherited_params: H::InheritedParams, + pub raw_params: Value, +} + +#[async_trait::async_trait] +pub trait Handler: Clone + Send + Sync + 'static { + type Params: Send + Sync; + type InheritedParams: Send + Sync; + type Ok: Send + Sync; + type Err: Send + Sync; + fn handle_sync(&self, handle_args: HandleArgs) -> Result { + handle_args + .context + .runtime() + .block_on(self.handle_async(handle_args)) + } + async fn handle_async( + &self, + handle_args: HandleArgs, + ) -> Result; + async fn handle_async_with_sync( + &self, + handle_args: HandleArgs, + ) -> Result { + self.handle_sync(handle_args) + } + async fn handle_async_with_sync_blocking( + &self, + handle_args: HandleArgs, + ) -> Result { + let s = self.clone(); + handle_args + .context + .runtime() + .spawn_blocking(move || s.handle_sync(handle_args)) + .await + .unwrap() + } + fn contexts(&self) -> Option> { + Context::type_ids_for(self) + } + #[allow(unused_variables)] + fn method_from_dots(&self, method: &str, ctx_ty: TypeId) -> Option> { + if method.is_empty() { + Some(VecDeque::new()) + } else { + None + } + } +} + +pub(crate) struct AnyHandler { _ctx: PhantomData, handler: H, } +impl AnyHandler { + pub(crate) fn new(handler: H) -> Self { + Self { + _ctx: PhantomData, + handler, + } + } +} +#[async_trait::async_trait] impl> HandleAny for AnyHandler where H::Params: DeserializeOwned, @@ -234,11 +346,24 @@ where ) .map_err(internal_error) } + async fn handle_async(&self, handle_args: HandleAnyArgs) -> Result { + imbl_value::to_value( + &self + .handler + .handle_async(handle_args.downcast().map_err(invalid_params)?) + .await?, + ) + .map_err(internal_error) + } + fn method_from_dots(&self, method: &str, ctx_ty: TypeId) -> Option> { + self.handler.method_from_dots(method, ctx_ty) + } } impl> CliBindingsAny for AnyHandler where - H::Params: FromArgMatches + CommandFactory + Serialize + DeserializeOwned, + H: CliBindings, + H::Params: DeserializeOwned, H::InheritedParams: DeserializeOwned, H::Ok: Serialize + DeserializeOwned, RpcError: From, @@ -270,14 +395,15 @@ pub struct NoParams {} pub enum Never {} #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -struct Name(Option<&'static str>); +pub(crate) struct Name(pub(crate) Option<&'static str>); impl<'a> std::borrow::Borrow> for Name { fn borrow(&self) -> &Option<&'a str> { &self.0 } } -struct SubcommandMap(BTreeMap, DynHandler>>); +#[derive(Clone)] +pub(crate) struct SubcommandMap(pub(crate) BTreeMap, DynHandler>>); impl SubcommandMap { fn insert( &mut self, @@ -311,7 +437,7 @@ impl SubcommandMap { pub struct ParentHandler { _phantom: PhantomData<(Params, InheritedParams)>, - subcommands: SubcommandMap, + pub(crate) subcommands: SubcommandMap, } impl ParentHandler { pub fn new() -> Self { @@ -321,24 +447,40 @@ impl ParentHandler { } } } +impl Clone for ParentHandler { + fn clone(&self) -> Self { + Self { + _phantom: PhantomData, + subcommands: self.subcommands.clone(), + } + } +} -struct InheritanceHandler< - Context: IntoContext, - Params, - InheritedParams, - H: Handler, - F: Fn(Params, InheritedParams) -> H::InheritedParams, -> { +struct InheritanceHandler { _phantom: PhantomData<(Context, Params, InheritedParams)>, handler: H, inherit: F, } +impl Clone + for InheritanceHandler +{ + fn clone(&self) -> Self { + Self { + _phantom: PhantomData, + handler: self.handler.clone(), + inherit: self.inherit.clone(), + } + } +} +#[async_trait::async_trait] impl Handler for InheritanceHandler where Context: IntoContext, + Params: Send + Sync + 'static, + InheritedParams: Send + Sync + 'static, H: Handler, - F: Fn(Params, InheritedParams) -> H::InheritedParams, + F: Fn(Params, InheritedParams) -> H::InheritedParams + Send + Sync + Clone + 'static, { type Params = H::Params; type InheritedParams = Flat; @@ -352,6 +494,7 @@ where method, params, inherited_params, + raw_params, }: HandleArgs, ) -> Result { self.handler.handle_sync(HandleArgs { @@ -360,6 +503,27 @@ where method, params, inherited_params: (self.inherit)(inherited_params.0, inherited_params.1), + raw_params, + }) + } + async fn handle_async( + &self, + HandleArgs { + context, + parent_method, + method, + params, + inherited_params, + raw_params, + }: HandleArgs, + ) -> Result { + self.handler.handle_sync(HandleArgs { + context, + parent_method, + method, + params, + inherited_params: (self.inherit)(inherited_params.0, inherited_params.1), + raw_params, }) } } @@ -368,8 +532,10 @@ impl PrintCliResult for InheritanceHandler where Context: IntoContext, + Params: Send + Sync + 'static, + InheritedParams: Send + Sync + 'static, H: Handler + PrintCliResult, - F: Fn(Params, InheritedParams) -> H::InheritedParams, + F: Fn(Params, InheritedParams) -> H::InheritedParams + Send + Sync + Clone + 'static, { fn print( &self, @@ -379,6 +545,7 @@ where method, params, inherited_params, + raw_params, }: HandleArgs, result: Self::Ok, ) -> Result<(), Self::Err> { @@ -389,13 +556,14 @@ where method, params, inherited_params: (self.inherit)(inherited_params.0, inherited_params.1), + raw_params, }, result, ) } } -impl ParentHandler { +impl ParentHandler { pub fn subcommand(mut self, name: &'static str, handler: H) -> Self where Context: IntoContext, @@ -436,7 +604,7 @@ impl ParentHandler { self } } -impl ParentHandler +impl ParentHandler where Params: DeserializeOwned + 'static, InheritedParams: DeserializeOwned + 'static, @@ -453,7 +621,7 @@ where H::Params: FromArgMatches + CommandFactory + Serialize + DeserializeOwned, H::Ok: Serialize + DeserializeOwned, RpcError: From, - F: Fn(Params, InheritedParams) -> H::InheritedParams + 'static, + F: Fn(Params, InheritedParams) -> H::InheritedParams + Send + Sync + Clone + 'static, { self.subcommands.insert( handler.contexts(), @@ -484,7 +652,7 @@ where H::Params: DeserializeOwned, H::Ok: Serialize, RpcError: From, - F: Fn(Params, InheritedParams) -> H::InheritedParams + 'static, + F: Fn(Params, InheritedParams) -> H::InheritedParams + Send + Sync + Clone + 'static, { self.subcommands.insert( handler.contexts(), @@ -507,7 +675,7 @@ where H::Params: FromArgMatches + CommandFactory + Serialize + DeserializeOwned, H::Ok: Serialize + DeserializeOwned, RpcError: From, - F: Fn(Params, InheritedParams) -> H::InheritedParams + 'static, + F: Fn(Params, InheritedParams) -> H::InheritedParams + Send + Sync + Clone + 'static, { self.subcommands.insert( handler.contexts(), @@ -533,7 +701,7 @@ where H::Params: DeserializeOwned, H::Ok: Serialize, RpcError: From, - F: Fn(Params, InheritedParams) -> H::InheritedParams + 'static, + F: Fn(Params, InheritedParams) -> H::InheritedParams + Send + Sync + Clone + 'static, { self.subcommands.insert( handler.contexts(), @@ -551,8 +719,12 @@ where } } -impl Handler - for ParentHandler +#[async_trait::async_trait] +impl< + Context: IntoContext, + Params: Serialize + Send + Sync + 'static, + InheritedParams: Serialize + Send + Sync + 'static, + > Handler for ParentHandler { type Params = Params; type InheritedParams = InheritedParams; @@ -564,9 +736,9 @@ impl Handler context, mut parent_method, mut method, - params, - inherited_params, - }: HandleArgs, + raw_params, + .. + }: HandleArgs, ) -> Result { let cmd = method.pop_front(); if let Some(cmd) = cmd { @@ -577,13 +749,39 @@ impl Handler context: context.upcast(), parent_method, method, - params: imbl_value::to_value(&Flat(params, inherited_params)) - .map_err(invalid_params)?, + params: raw_params, }) } else { Err(yajrc::METHOD_NOT_FOUND_ERROR) } } + async fn handle_async( + &self, + HandleArgs { + context, + mut parent_method, + mut method, + raw_params, + .. + }: HandleArgs, + ) -> Result { + let cmd = method.pop_front(); + if let Some(cmd) = cmd { + parent_method.push(cmd); + } + if let Some((_, sub_handler)) = self.subcommands.get(context.inner_type_id(), cmd) { + sub_handler + .handle_async(HandleAnyArgs { + context: context.upcast(), + parent_method, + method, + params: raw_params, + }) + .await + } else { + Err(yajrc::METHOD_NOT_FOUND_ERROR) + } + } fn contexts(&self) -> Option> { let mut set = BTreeSet::new(); for ctx_ty in self.subcommands.0.values().flat_map(|c| c.keys()) { @@ -591,12 +789,31 @@ impl Handler } Some(set) } + fn method_from_dots(&self, method: &str, ctx_ty: TypeId) -> Option> { + let (head, tail) = if method.is_empty() { + (None, None) + } else { + method + .split_once(".") + .map(|(head, tail)| (Some(head), Some(tail))) + .unwrap_or((Some(method), None)) + }; + let (Name(name), h) = self.subcommands.get(ctx_ty, head)?; + let mut res = VecDeque::new(); + if let Some(name) = name { + res.push_back(name); + } + if let Some(tail) = tail { + res.append(&mut h.method_from_dots(tail, ctx_ty)?); + } + Some(res) + } } impl CliBindings for ParentHandler where - Params: FromArgMatches + CommandFactory + Serialize, - InheritedParams: Serialize, + Params: FromArgMatches + CommandFactory + Serialize + Send + Sync + 'static, + InheritedParams: Serialize + Send + Sync + 'static, { fn cli_command(&self, ctx_ty: TypeId) -> Command { let mut base = Params::command(); @@ -648,8 +865,8 @@ where context, mut parent_method, mut method, - params, - inherited_params, + raw_params, + .. }: HandleArgs, result: Self::Ok, ) -> Result<(), Self::Err> { @@ -665,8 +882,7 @@ where context, parent_method, method, - params: imbl_value::to_value(&Flat(params, inherited_params)) - .map_err(invalid_params)?, + params: raw_params, }, result, ) @@ -679,19 +895,61 @@ where pub struct FromFn { _phantom: PhantomData<(T, E, Args)>, function: F, + blocking: bool, +} +impl Clone for FromFn { + fn clone(&self) -> Self { + Self { + _phantom: PhantomData, + function: self.function.clone(), + blocking: self.blocking, + } + } } pub fn from_fn(function: F) -> FromFn { FromFn { function, _phantom: PhantomData, + blocking: false, } } +pub fn from_fn_blocking(function: F) -> FromFn { + FromFn { + function, + _phantom: PhantomData, + blocking: true, + } +} + +pub struct FromFnAsync { + _phantom: PhantomData<(Fut, T, E, Args)>, + function: F, +} +impl Clone for FromFnAsync { + fn clone(&self) -> Self { + Self { + _phantom: PhantomData, + function: self.function.clone(), + } + } +} + +pub fn from_fn_async(function: F) -> FromFnAsync { + FromFnAsync { + function, + _phantom: PhantomData, + } +} + +#[async_trait::async_trait] impl Handler for FromFn where Context: IntoContext, - F: Fn() -> Result, + F: Fn() -> Result + Send + Sync + Clone + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, { type Params = NoParams; type InheritedParams = NoParams; @@ -700,12 +958,42 @@ where fn handle_sync(&self, _: HandleArgs) -> Result { (self.function)() } + async fn handle_async( + &self, + handle_args: HandleArgs, + ) -> Result { + if self.blocking { + self.handle_async_with_sync_blocking(handle_args).await + } else { + self.handle_async_with_sync(handle_args).await + } + } +} +#[async_trait::async_trait] +impl Handler for FromFnAsync +where + Context: IntoContext, + F: Fn() -> Fut + Send + Sync + Clone + 'static, + Fut: Future> + Send + Sync + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, +{ + type Params = NoParams; + type InheritedParams = NoParams; + type Ok = T; + type Err = E; + async fn handle_async(&self, _: HandleArgs) -> Result { + (self.function)().await + } } +#[async_trait::async_trait] impl Handler for FromFn where Context: IntoContext, - F: Fn(Context) -> Result, + F: Fn(Context) -> Result + Send + Sync + Clone + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, { type Params = NoParams; type InheritedParams = NoParams; @@ -714,12 +1002,46 @@ where fn handle_sync(&self, handle_args: HandleArgs) -> Result { (self.function)(handle_args.context) } + async fn handle_async( + &self, + handle_args: HandleArgs, + ) -> Result { + if self.blocking { + self.handle_async_with_sync_blocking(handle_args).await + } else { + self.handle_async_with_sync(handle_args).await + } + } } +#[async_trait::async_trait] +impl Handler for FromFnAsync +where + Context: IntoContext, + F: Fn(Context) -> Fut + Send + Sync + Clone + 'static, + Fut: Future> + Send + Sync + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, +{ + type Params = NoParams; + type InheritedParams = NoParams; + type Ok = T; + type Err = E; + async fn handle_async( + &self, + handle_args: HandleArgs, + ) -> Result { + (self.function)(handle_args.context).await + } +} + +#[async_trait::async_trait] impl Handler for FromFn where Context: IntoContext, - F: Fn(Context, Params) -> Result, - Params: DeserializeOwned, + F: Fn(Context, Params) -> Result + Send + Sync + Clone + 'static, + Params: DeserializeOwned + Send + Sync + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, { type Params = Params; type InheritedParams = NoParams; @@ -731,14 +1053,53 @@ where } = handle_args; (self.function)(context, params) } + async fn handle_async( + &self, + handle_args: HandleArgs, + ) -> Result { + if self.blocking { + self.handle_async_with_sync_blocking(handle_args).await + } else { + self.handle_async_with_sync(handle_args).await + } + } } +#[async_trait::async_trait] +impl Handler + for FromFnAsync +where + Context: IntoContext, + F: Fn(Context, Params) -> Fut + Send + Sync + Clone + 'static, + Fut: Future> + Send + Sync + 'static, + Params: DeserializeOwned + Send + Sync + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, +{ + type Params = Params; + type InheritedParams = NoParams; + type Ok = T; + type Err = E; + async fn handle_async( + &self, + handle_args: HandleArgs, + ) -> Result { + let HandleArgs { + context, params, .. + } = handle_args; + (self.function)(context, params).await + } +} + +#[async_trait::async_trait] impl Handler for FromFn where Context: IntoContext, - F: Fn(Context, Params, InheritedParams) -> Result, - Params: DeserializeOwned, - InheritedParams: DeserializeOwned, + F: Fn(Context, Params, InheritedParams) -> Result + Send + Sync + Clone + 'static, + Params: DeserializeOwned + Send + Sync + 'static, + InheritedParams: DeserializeOwned + Send + Sync + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, { type Params = Params; type InheritedParams = InheritedParams; @@ -753,44 +1114,43 @@ where } = handle_args; (self.function)(context, params, inherited_params) } + async fn handle_async( + &self, + handle_args: HandleArgs, + ) -> Result { + if self.blocking { + self.handle_async_with_sync_blocking(handle_args).await + } else { + self.handle_async_with_sync(handle_args).await + } + } +} +#[async_trait::async_trait] +impl Handler + for FromFnAsync +where + Context: IntoContext, + F: Fn(Context, Params, InheritedParams) -> Fut + Send + Sync + Clone + 'static, + Fut: Future> + Send + Sync + 'static, + Params: DeserializeOwned + Send + Sync + 'static, + InheritedParams: DeserializeOwned + Send + Sync + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, +{ + type Params = Params; + type InheritedParams = InheritedParams; + type Ok = T; + type Err = E; + async fn handle_async( + &self, + handle_args: HandleArgs, + ) -> Result { + let HandleArgs { + context, + params, + inherited_params, + .. + } = handle_args; + (self.function)(context, params, inherited_params).await + } } - -#[derive(Parser)] -#[command(about = "this is db stuff")] -struct DbParams {} - -// Server::new( -// ParentCommand::new() -// .subcommand("foo", from_fn(foo)) -// .subcommand("db", -// ParentCommand::new::() -// .subcommand("dump", from_fn(dump)) -// ) -// ) - -// Server::new() -// .handle( -// "db", -// with_description("Description maybe?") -// .handle("dump", from_fn(dump_route)) -// ) -// .handle( -// "server", -// no_description() -// .handle("version", from_fn(version)) -// ) - -// #[derive(clap::Parser)] -// struct DumpParams { -// test: Option -// } - -// fn dump_route(context: Context, param: Param) -> Result { -// Ok(json!({ -// "db": {} -// })) -// } - -// fn version() -> &'static str { -// "1.0.0" -// } diff --git a/rpc-toolkit/src/lib.rs b/rpc-toolkit/src/lib.rs index fec86b5..bcbf617 100644 --- a/rpc-toolkit/src/lib.rs +++ b/rpc-toolkit/src/lib.rs @@ -1,4 +1,4 @@ -// pub use cli::*; +pub use cli::*; // pub use command::*; pub use context::*; pub use handler::*; @@ -24,15 +24,11 @@ pub use handler::*; /// /// See also: [arg](rpc_toolkit_macro::arg), [context](rpc_toolkit_macro::context) pub use rpc_toolkit_macro::command; -// pub use server::*; +pub use server::*; pub use {clap, futures, hyper, reqwest, serde, serde_json, tokio, url, yajrc}; -// mod cli; -// mod command; -mod handler; -// pub mod command_helpers; +mod cli; mod context; -// mod metadata; -// pub mod rpc_server_helpers; -// mod server; +mod handler; +mod server; mod util; diff --git a/rpc-toolkit/src/server/http.rs b/rpc-toolkit/src/server/http.rs index 29bdee6..ab40aeb 100644 --- a/rpc-toolkit/src/server/http.rs +++ b/rpc-toolkit/src/server/http.rs @@ -19,7 +19,7 @@ pub trait Middleware { async fn process_rpc_request( &self, prev: Self::ProcessHttpRequestResult, - metadata: &Context::Metadata, + // metadata: &Context::Metadata, req: &mut RpcRequest, ) -> Result; type ProcessRpcResponseResult; diff --git a/rpc-toolkit/src/server/mod.rs b/rpc-toolkit/src/server/mod.rs index 2ee1ca1..5920717 100644 --- a/rpc-toolkit/src/server/mod.rs +++ b/rpc-toolkit/src/server/mod.rs @@ -1,13 +1,13 @@ -use std::borrow::Cow; +use std::any::TypeId; use std::sync::Arc; use futures::future::{join_all, BoxFuture}; use futures::{Future, FutureExt, Stream, StreamExt}; use imbl_value::Value; -use yajrc::{AnyParams, AnyRpcMethod, RpcError, RpcMethod}; +use yajrc::{RpcError, RpcMethod}; use crate::util::{invalid_request, JobRunner}; -use crate::DynCommand; +use crate::{AnyHandler, HandleAny, HandleAnyArgs, IntoContext, ParentHandler}; type GenericRpcMethod = yajrc::GenericRpcMethod; type RpcRequest = yajrc::RpcRequest; @@ -20,41 +20,21 @@ mod socket; pub use http::*; pub use socket::*; -impl DynCommand { - fn cmd_from_method( - &self, - method: &[&str], - parent_method: Vec<&'static str>, - ) -> Result<(Vec<&'static str>, &DynCommand), RpcError> { - let mut ret_method = parent_method; - ret_method.push(self.name); - if let Some((cmd, rest)) = method.split_first() { - self.subcommands - .iter() - .find(|c| c.name == *cmd) - .ok_or(yajrc::METHOD_NOT_FOUND_ERROR)? - .cmd_from_method(rest, ret_method) - } else { - Ok((ret_method, self)) - } - } -} - pub struct Server { - commands: Vec>, make_ctx: Arc BoxFuture<'static, Result> + Send + Sync>, + root_handler: Arc>, } impl Server { pub fn new< - F: Fn() -> Fut + Send + Sync + 'static, + MakeCtx: Fn() -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, >( - commands: Vec>, - make_ctx: F, + make_ctx: MakeCtx, + root_handler: ParentHandler, ) -> Self { Server { - commands, make_ctx: Arc::new(move || make_ctx().boxed()), + root_handler: Arc::new(AnyHandler::new(root_handler)), } } @@ -63,30 +43,22 @@ impl Server { method: &str, params: Value, ) -> impl Future> + Send + 'static { - let from_self = (|| { - let method: Vec<_> = method.split(".").collect(); - let (cmd, rest) = method.split_first().ok_or(yajrc::METHOD_NOT_FOUND_ERROR)?; - let (method, cmd) = self - .commands - .iter() - .find(|c| c.name == *cmd) - .ok_or(yajrc::METHOD_NOT_FOUND_ERROR)? - .cmd_from_method(rest, Vec::new())?; - Ok::<_, RpcError>(( - cmd.implementation - .as_ref() - .ok_or(yajrc::METHOD_NOT_FOUND_ERROR)? - .async_impl - .clone(), - self.make_ctx.clone(), - method, - params, - )) - })(); + let (make_ctx, root_handler, method) = ( + self.make_ctx.clone(), + self.root_handler.clone(), + self.root_handler + .method_from_dots(method, TypeId::of::()), + ); async move { - let (implementation, make_ctx, method, params) = from_self?; - implementation(make_ctx().await?, method, params).await + root_handler + .handle_async(HandleAnyArgs { + context: make_ctx().await?.upcast(), + parent_method: Vec::new(), + method: method.ok_or_else(|| yajrc::METHOD_NOT_FOUND_ERROR)?, + params, + }) + .await } } diff --git a/rpc-toolkit/tests/handler.rs b/rpc-toolkit/tests/handler.rs index 5b0eaeb..bac9e77 100644 --- a/rpc-toolkit/tests/handler.rs +++ b/rpc-toolkit/tests/handler.rs @@ -69,10 +69,8 @@ impl Context for CliContext { // } fn make_api() -> ParentHandler { - ParentHandler::new().subcommand_no_cli( - Some("hello"), - from_fn(|_: CliContext| Ok::<_, RpcError>("world")), - ) + ParentHandler::new() + .subcommand_no_cli("hello", from_fn(|_: CliContext| Ok::<_, RpcError>("world"))) } pub fn internal_error(e: impl Display) -> RpcError {