From 068db905ee38a7da97cc4a43b806409204e73723 Mon Sep 17 00:00:00 2001 From: Aiden McClelland Date: Tue, 4 Nov 2025 15:58:10 -0700 Subject: [PATCH] from_fn_async_local --- src/handler/from_fn.rs | 328 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 328 insertions(+) diff --git a/src/handler/from_fn.rs b/src/handler/from_fn.rs index 1311805..ac875a9 100644 --- a/src/handler/from_fn.rs +++ b/src/handler/from_fn.rs @@ -218,6 +218,100 @@ where } } +pub struct FromFnAsyncLocal { + _phantom: PhantomData<(Fut, T, E, Args)>, + function: F, + metadata: OrdMap<&'static str, Value>, +} +impl FromFnAsyncLocal { + pub fn with_metadata(mut self, key: &'static str, value: Value) -> Self { + self.metadata.insert(key, value); + self + } +} +impl Clone for FromFnAsyncLocal { + fn clone(&self) -> Self { + Self { + _phantom: PhantomData::new(), + function: self.function.clone(), + metadata: self.metadata.clone(), + } + } +} +impl std::fmt::Debug for FromFnAsyncLocal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FromFnAsyncLocal").finish() + } +} +impl PrintCliResult for FromFnAsyncLocal +where + Context: crate::Context, + Self: HandlerTypes, + ::Ok: Display, +{ + fn print(&self, _: HandlerArgsFor, result: Self::Ok) -> Result<(), Self::Err> { + Ok(println!("{result}")) + } +} +impl CliBindings for FromFnAsyncLocal +where + Context: crate::Context, + Self: HandlerTypes, + Self::Params: CommandFactory + FromArgMatches + Serialize, + Self: PrintCliResult, +{ + fn cli_command(&self) -> clap::Command { + Self::Params::command() + } + fn cli_parse( + &self, + matches: &clap::ArgMatches, + ) -> Result<(VecDeque<&'static str>, Value), clap::Error> { + Self::Params::from_arg_matches(matches).and_then(|a| { + Ok(( + VecDeque::new(), + imbl_value::to_value(&a) + .map_err(|e| clap::Error::raw(clap::error::ErrorKind::ValueValidation, e))?, + )) + }) + } + fn cli_display( + &self, + HandlerArgs { + context, + parent_method, + method, + params, + inherited_params, + raw_params, + }: HandlerArgsFor, + result: Self::Ok, + ) -> Result<(), Self::Err> { + self.print( + HandlerArgs { + context, + parent_method, + method, + params, + inherited_params, + raw_params, + }, + result, + ) + } +} + +pub fn from_fn_async_local(function: F) -> FromFnAsyncLocal +where + FromFnAsyncLocal: HandlerTypes, +{ + FromFnAsyncLocal { + function, + _phantom: PhantomData::new(), + metadata: OrdMap::new(), + } +} + impl HandlerTypes for FromFn> where @@ -629,3 +723,237 @@ where self.metadata.clone() } } + +impl HandlerTypes + for FromFnAsyncLocal> +where + F: Fn(HandlerArgs) -> Fut + Send + Sync + Clone + 'static, + Fut: Future> + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, + Context: crate::Context, + Params: Send + Sync, + InheritedParams: Send + Sync, +{ + type Params = Params; + type InheritedParams = InheritedParams; + type Ok = T; + type Err = E; +} + +impl HandlerFor + for FromFnAsyncLocal> +where + F: Fn(HandlerArgs) -> Fut + Send + Sync + Clone + 'static, + Fut: Future> + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, + Context: crate::Context, + Params: Send + Sync + 'static, + InheritedParams: Send + Sync + 'static, +{ + fn handle_sync( + &self, + handle_args: HandlerArgsFor, + ) -> Result { + let local = tokio::task::LocalSet::new(); + if let Some(rt) = handle_args.context.runtime() { + local.block_on(&*rt, (self.function)(handle_args)) + } else { + tokio::runtime::Handle::current().block_on(local.run_until((self.function)(handle_args))) + } + } + async fn handle_async( + &self, + handle_args: HandlerArgsFor, + ) -> Result { + self.handle_async_with_sync_blocking(handle_args).await + } + fn metadata(&self, _: VecDeque<&'static str>) -> OrdMap<&'static str, Value> { + self.metadata.clone() + } +} + +impl HandlerTypes for FromFnAsyncLocal +where + F: Fn() -> Fut + Send + Sync + Clone + 'static, + Fut: Future> + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, +{ + type Params = Empty; + type InheritedParams = Empty; + type Ok = T; + type Err = E; +} + +impl HandlerFor for FromFnAsyncLocal +where + Context: crate::Context, + F: Fn() -> Fut + Send + Sync + Clone + 'static, + Fut: Future> + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, +{ + fn handle_sync( + &self, + handle_args: HandlerArgsFor, + ) -> Result { + let local = tokio::task::LocalSet::new(); + if let Some(rt) = handle_args.context.runtime() { + local.block_on(&*rt, (self.function)()) + } else { + tokio::runtime::Handle::current().block_on(local.run_until((self.function)())) + } + } + async fn handle_async( + &self, + handle_args: HandlerArgsFor, + ) -> Result { + self.handle_async_with_sync_blocking(handle_args).await + } + fn metadata(&self, _: VecDeque<&'static str>) -> OrdMap<&'static str, Value> { + self.metadata.clone() + } +} + +impl HandlerTypes for FromFnAsyncLocal +where + Context: crate::Context, + F: Fn(Context) -> Fut + Send + Sync + Clone + 'static, + Fut: Future> + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, +{ + type Params = Empty; + type InheritedParams = Empty; + type Ok = T; + type Err = E; +} + +impl HandlerFor for FromFnAsyncLocal +where + Context: crate::Context, + F: Fn(Context) -> Fut + Send + Sync + Clone + 'static, + Fut: Future> + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, +{ + fn handle_sync( + &self, + handle_args: HandlerArgsFor, + ) -> Result { + let local = tokio::task::LocalSet::new(); + if let Some(rt) = handle_args.context.runtime() { + local.block_on(&*rt, (self.function)(handle_args.context)) + } else { + tokio::runtime::Handle::current().block_on(local.run_until((self.function)(handle_args.context))) + } + } + async fn handle_async( + &self, + handle_args: HandlerArgsFor, + ) -> Result { + self.handle_async_with_sync_blocking(handle_args).await + } + fn metadata(&self, _: VecDeque<&'static str>) -> OrdMap<&'static str, Value> { + self.metadata.clone() + } +} + +impl HandlerTypes for FromFnAsyncLocal +where + Context: crate::Context, + F: Fn(Context, Params) -> Fut + Send + Sync + Clone + 'static, + Fut: Future> + 'static, + Params: DeserializeOwned + Send + Sync + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, +{ + type Params = Params; + type InheritedParams = Empty; + type Ok = T; + type Err = E; +} + +impl HandlerFor + for FromFnAsyncLocal +where + Context: crate::Context, + F: Fn(Context, Params) -> Fut + Send + Sync + Clone + 'static, + Fut: Future> + 'static, + Params: DeserializeOwned + Send + Sync + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, +{ + fn handle_sync( + &self, + handle_args: HandlerArgsFor, + ) -> Result { + let local = tokio::task::LocalSet::new(); + if let Some(rt) = handle_args.context.runtime() { + local.block_on(&*rt, (self.function)(handle_args.context, handle_args.params)) + } else { + tokio::runtime::Handle::current().block_on(local.run_until((self.function)(handle_args.context, handle_args.params))) + } + } + async fn handle_async( + &self, + handle_args: HandlerArgsFor, + ) -> Result { + self.handle_async_with_sync_blocking(handle_args).await + } + fn metadata(&self, _: VecDeque<&'static str>) -> OrdMap<&'static str, Value> { + self.metadata.clone() + } +} + +impl HandlerTypes + for FromFnAsyncLocal +where + Context: crate::Context, + F: Fn(Context, Params, InheritedParams) -> Fut + Send + Sync + Clone + 'static, + Fut: Future> + 'static, + Params: DeserializeOwned + Send + Sync + 'static, + InheritedParams: Send + Sync + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, +{ + type Params = Params; + type InheritedParams = InheritedParams; + type Ok = T; + type Err = E; +} + +impl HandlerFor + for FromFnAsyncLocal +where + Context: crate::Context, + F: Fn(Context, Params, InheritedParams) -> Fut + Send + Sync + Clone + 'static, + Fut: Future> + 'static, + Params: DeserializeOwned + Send + Sync + 'static, + InheritedParams: Send + Sync + 'static, + T: Send + Sync + 'static, + E: Send + Sync + 'static, +{ + fn handle_sync( + &self, + handle_args: HandlerArgsFor, + ) -> Result { + let local = tokio::task::LocalSet::new(); + if let Some(rt) = handle_args.context.runtime() { + local.block_on(&*rt, (self.function)(handle_args.context, handle_args.params, handle_args.inherited_params)) + } else { + tokio::runtime::Handle::current().block_on(local.run_until((self.function)(handle_args.context, handle_args.params, handle_args.inherited_params))) + } + } + async fn handle_async( + &self, + handle_args: HandlerArgsFor, + ) -> Result { + self.handle_async_with_sync_blocking(handle_args).await + } + fn metadata(&self, _: VecDeque<&'static str>) -> OrdMap<&'static str, Value> { + self.metadata.clone() + } +}