add Extra to CallRemoteHandler

This commit is contained in:
Aiden McClelland
2024-05-03 13:49:31 -06:00
parent 5279528a97
commit 33229337a4

View File

@@ -13,7 +13,7 @@ use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader
use url::Url; use url::Url;
use yajrc::{Id, RpcError}; use yajrc::{Id, RpcError};
use crate::util::{internal_error, parse_error, PhantomData}; use crate::util::{internal_error, invalid_params, parse_error, without, Flat, PhantomData};
use crate::{ use crate::{
AnyHandler, CliBindingsAny, DynHandler, Empty, HandleAny, HandleAnyArgs, Handler, HandlerArgs, AnyHandler, CliBindingsAny, DynHandler, Empty, HandleAny, HandleAnyArgs, Handler, HandlerArgs,
HandlerArgsFor, HandlerTypes, IntoContext, Name, ParentHandler, PrintCliResult, HandlerArgsFor, HandlerTypes, IntoContext, Name, ParentHandler, PrintCliResult,
@@ -172,11 +172,11 @@ pub async fn call_remote_socket(
.result .result
} }
struct CallRemoteHandler<Context, RemoteHandler> { pub struct CallRemoteHandler<Context, RemoteHandler, Extra = Empty> {
_phantom: PhantomData<Context>, _phantom: PhantomData<(Context, Extra)>,
handler: RemoteHandler, handler: RemoteHandler,
} }
impl<Context, RemoteHandler> CallRemoteHandler<Context, RemoteHandler> { impl<Context, RemoteHandler, Extra> CallRemoteHandler<Context, RemoteHandler, Extra> {
pub fn new(handler: RemoteHandler) -> Self { pub fn new(handler: RemoteHandler) -> Self {
Self { Self {
_phantom: PhantomData::new(), _phantom: PhantomData::new(),
@@ -184,7 +184,9 @@ impl<Context, RemoteHandler> CallRemoteHandler<Context, RemoteHandler> {
} }
} }
} }
impl<Context, RemoteHandler: Clone> Clone for CallRemoteHandler<Context, RemoteHandler> { impl<Context, RemoteHandler: Clone, Extra> Clone
for CallRemoteHandler<Context, RemoteHandler, Extra>
{
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
_phantom: PhantomData::new(), _phantom: PhantomData::new(),
@@ -192,34 +194,39 @@ impl<Context, RemoteHandler: Clone> Clone for CallRemoteHandler<Context, RemoteH
} }
} }
} }
impl<Context, RemoteHandler> std::fmt::Debug for CallRemoteHandler<Context, RemoteHandler> { impl<Context, RemoteHandler, Extra> std::fmt::Debug
for CallRemoteHandler<Context, RemoteHandler, Extra>
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("CallRemoteHandler").finish() f.debug_tuple("CallRemoteHandler").finish()
} }
} }
impl<Context, RemoteHandler> HandlerTypes for CallRemoteHandler<Context, RemoteHandler> impl<Context, RemoteHandler, Extra> HandlerTypes
for CallRemoteHandler<Context, RemoteHandler, Extra>
where where
RemoteHandler: HandlerTypes, RemoteHandler: HandlerTypes,
RemoteHandler::Params: Serialize, RemoteHandler::Params: Serialize,
RemoteHandler::InheritedParams: Serialize, RemoteHandler::InheritedParams: Serialize,
RemoteHandler::Ok: DeserializeOwned, RemoteHandler::Ok: DeserializeOwned,
RemoteHandler::Err: From<RpcError>, RemoteHandler::Err: From<RpcError>,
Extra: Send + Sync + 'static,
{ {
type Params = RemoteHandler::Params; type Params = Flat<RemoteHandler::Params, Extra>;
type InheritedParams = RemoteHandler::InheritedParams; type InheritedParams = RemoteHandler::InheritedParams;
type Ok = RemoteHandler::Ok; type Ok = RemoteHandler::Ok;
type Err = RemoteHandler::Err; type Err = RemoteHandler::Err;
} }
impl<Context, RemoteHandler> Handler for CallRemoteHandler<Context, RemoteHandler> impl<Context, RemoteHandler, Extra> Handler for CallRemoteHandler<Context, RemoteHandler, Extra>
where where
Context: CallRemote<RemoteHandler::Context>, Context: CallRemote<RemoteHandler::Context, Extra>,
RemoteHandler: Handler, RemoteHandler: Handler,
RemoteHandler::Params: Serialize, RemoteHandler::Params: Serialize,
RemoteHandler::InheritedParams: Serialize, RemoteHandler::InheritedParams: Serialize,
RemoteHandler::Ok: DeserializeOwned, RemoteHandler::Ok: DeserializeOwned,
RemoteHandler::Err: From<RpcError>, RemoteHandler::Err: From<RpcError>,
Extra: Serialize + Send + Sync + 'static,
{ {
type Context = Context; type Context = Context;
async fn handle_async( async fn handle_async(
@@ -235,8 +242,9 @@ where
.context .context
.call_remote( .call_remote(
&full_method.join("."), &full_method.join("."),
handle_args.raw_params.clone(), without(handle_args.raw_params.clone(), &handle_args.params.1)
Empty {}, .map_err(invalid_params)?,
handle_args.params.1,
) )
.await .await
{ {
@@ -247,7 +255,8 @@ where
} }
} }
} }
impl<Context, RemoteHandler> PrintCliResult for CallRemoteHandler<Context, RemoteHandler> impl<Context, RemoteHandler, Extra> PrintCliResult
for CallRemoteHandler<Context, RemoteHandler, Extra>
where where
Context: CallRemote<RemoteHandler::Context>, Context: CallRemote<RemoteHandler::Context>,
RemoteHandler: PrintCliResult<Context = Context>, RemoteHandler: PrintCliResult<Context = Context>,
@@ -255,6 +264,7 @@ where
RemoteHandler::InheritedParams: Serialize, RemoteHandler::InheritedParams: Serialize,
RemoteHandler::Ok: DeserializeOwned, RemoteHandler::Ok: DeserializeOwned,
RemoteHandler::Err: From<RpcError>, RemoteHandler::Err: From<RpcError>,
Extra: Send + Sync + 'static,
{ {
type Context = Context; type Context = Context;
fn print( fn print(
@@ -274,7 +284,7 @@ where
context, context,
parent_method, parent_method,
method, method,
params, params: params.0,
inherited_params, inherited_params,
raw_params, raw_params,
}, },