diff --git a/src/context.rs b/src/context.rs index b1b2054..d7b1462 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,7 +1,9 @@ -use tokio::runtime::Handle; +use std::sync::Arc; + +use tokio::runtime::Runtime; pub trait Context: Send + Sync + 'static { - fn runtime(&self) -> Handle { - Handle::current() + fn runtime(&self) -> Option> { + None } } diff --git a/src/handler/mod.rs b/src/handler/mod.rs index 99de2cf..796deeb 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -213,10 +213,11 @@ pub trait HandlerFor: &self, handle_args: HandlerArgsFor, ) -> Result { - handle_args - .context - .runtime() - .block_on(self.handle_async(handle_args)) + if let Some(rt) = handle_args.context.runtime() { + rt.block_on(self.handle_async(handle_args)) + } else { + tokio::runtime::Handle::current().block_on(self.handle_async(handle_args)) + } } fn handle_async( &self, @@ -234,12 +235,14 @@ pub trait HandlerFor: ) -> impl Future> + Send + 'a { async move { let s = self.clone(); - handle_args - .context - .runtime() - .spawn_blocking(move || s.handle_sync(handle_args)) - .await - .unwrap() + if let Some(rt) = handle_args.context.runtime() { + rt.spawn_blocking(move || s.handle_sync(handle_args)).await + } else { + tokio::runtime::Handle::current() + .spawn_blocking(move || s.handle_sync(handle_args)) + .await + } + .unwrap() } } #[allow(unused_variables)] diff --git a/tests/handler.rs b/tests/handler.rs index 4e2ce12..9e89d0e 100644 --- a/tests/handler.rs +++ b/tests/handler.rs @@ -48,16 +48,19 @@ impl CliConfig { struct CliContextSeed { host: PathBuf, - rt: OnceCell, + rt: OnceCell>, } #[derive(Clone)] struct CliContext(Arc); impl Context for CliContext { - fn runtime(&self) -> Handle { + fn runtime(&self) -> Option> { if self.0.rt.get().is_none() { - self.0.rt.set(Runtime::new().unwrap()).unwrap(); + let rt = Arc::new(Runtime::new().unwrap()); + self.0.rt.set(rt.clone()).unwrap_or_default(); + Some(rt) + } else { + self.0.rt.get().cloned() } - self.0.rt.get().unwrap().handle().clone() } }