allow concurrency in service actor (#2592)

This commit is contained in:
Aiden McClelland
2024-04-08 11:53:35 -06:00
committed by GitHub
parent 75ff541aec
commit e41f8f1d0f
16 changed files with 535 additions and 129 deletions

View File

@@ -0,0 +1,60 @@
use futures::future::BoxFuture;
use futures::{Future, FutureExt};
use tokio::sync::mpsc;
#[derive(Clone)]
pub struct BackgroundJobQueue(mpsc::UnboundedSender<BoxFuture<'static, ()>>);
impl BackgroundJobQueue {
pub fn new() -> (Self, BackgroundJobRunner) {
let (send, recv) = mpsc::unbounded_channel();
(
Self(send),
BackgroundJobRunner {
recv,
jobs: Vec::new(),
},
)
}
pub fn add_job(&self, fut: impl Future<Output = ()> + Send + 'static) {
let _ = self.0.send(fut.boxed());
}
}
pub struct BackgroundJobRunner {
recv: mpsc::UnboundedReceiver<BoxFuture<'static, ()>>,
jobs: Vec<BoxFuture<'static, ()>>,
}
impl BackgroundJobRunner {
pub fn is_empty(&self) -> bool {
self.recv.is_empty() && self.jobs.is_empty()
}
}
impl Future for BackgroundJobRunner {
type Output = ();
fn poll(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
while let std::task::Poll::Ready(Some(job)) = self.recv.poll_recv(cx) {
self.jobs.push(job);
}
let complete = self
.jobs
.iter_mut()
.enumerate()
.filter_map(|(i, f)| match f.poll_unpin(cx) {
std::task::Poll::Pending => None,
std::task::Poll::Ready(_) => Some(i),
})
.collect::<Vec<_>>();
for idx in complete.into_iter().rev() {
#[allow(clippy::let_underscore_future)]
let _ = self.jobs.swap_remove(idx);
}
if self.jobs.is_empty() && self.recv.is_closed() {
std::task::Poll::Ready(())
} else {
std::task::Poll::Pending
}
}
}

View File

@@ -0,0 +1,208 @@
use std::any::Any;
use std::sync::Arc;
use std::time::Duration;
use futures::future::{ready, BoxFuture};
use futures::{Future, FutureExt, TryFutureExt};
use helpers::NonDetachingJoinHandle;
use tokio::sync::{mpsc, oneshot};
use crate::prelude::*;
use crate::util::actor::background::{BackgroundJobQueue, BackgroundJobRunner};
use crate::util::actor::{Actor, ConflictFn, Handler, PendingMessageStrategy, Request};
#[pin_project::pin_project]
struct ConcurrentRunner<A> {
actor: A,
shutdown: Option<oneshot::Receiver<()>>,
waiting: Vec<Request<A>>,
recv: mpsc::UnboundedReceiver<Request<A>>,
handlers: Vec<(
Arc<ConflictFn<A>>,
oneshot::Sender<Box<dyn Any + Send>>,
BoxFuture<'static, Box<dyn Any + Send>>,
)>,
queue: BackgroundJobQueue,
#[pin]
bg_runner: BackgroundJobRunner,
}
impl<A: Actor + Clone> Future for ConcurrentRunner<A> {
type Output = ();
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let mut this = self.project();
*this.shutdown = this.shutdown.take().and_then(|mut s| {
if s.poll_unpin(cx).is_pending() {
Some(s)
} else {
None
}
});
if this.shutdown.is_some() {
while let std::task::Poll::Ready(Some((msg, reply))) = this.recv.poll_recv(cx) {
if this.handlers.iter().any(|(f, _, _)| f(&*msg)) {
this.waiting.push((msg, reply));
} else {
let mut actor = this.actor.clone();
let queue = this.queue.clone();
this.handlers.push((
msg.conflicts_with(),
reply,
async move { msg.handle_with(&mut actor, &queue).await }.boxed(),
))
}
}
}
// handlers
while {
let mut cont = false;
let complete = this
.handlers
.iter_mut()
.enumerate()
.filter_map(|(i, (_, _, f))| match f.poll_unpin(cx) {
std::task::Poll::Pending => None,
std::task::Poll::Ready(res) => Some((i, res)),
})
.collect::<Vec<_>>();
for (idx, res) in complete.into_iter().rev() {
#[allow(clippy::let_underscore_future)]
let (f, reply, _) = this.handlers.swap_remove(idx);
let _ = reply.send(res);
// TODO: replace with Vec::extract_if once stable
if this.shutdown.is_some() {
let mut i = 0;
while i < this.waiting.len() {
if f(&*this.waiting[i].0)
&& !this.handlers.iter().any(|(f, _, _)| f(&*this.waiting[i].0))
{
let (msg, reply) = this.waiting.remove(i);
let mut actor = this.actor.clone();
let queue = this.queue.clone();
this.handlers.push((
msg.conflicts_with(),
reply,
async move { msg.handle_with(&mut actor, &queue).await }.boxed(),
));
cont = true;
} else {
i += 1;
}
}
}
}
cont
} {}
let _ = this.bg_runner.as_mut().poll(cx);
if this.waiting.is_empty() && this.handlers.is_empty() && this.bg_runner.is_empty() {
std::task::Poll::Ready(())
} else {
std::task::Poll::Pending
}
}
}
pub struct ConcurrentActor<A: Actor + Clone> {
shutdown: oneshot::Sender<()>,
runtime: NonDetachingJoinHandle<()>,
messenger: mpsc::UnboundedSender<Request<A>>,
}
impl<A: Actor + Clone> ConcurrentActor<A> {
pub fn new(mut actor: A) -> Self {
let (shutdown_send, shutdown_recv) = oneshot::channel();
let (messenger_send, messenger_recv) = mpsc::unbounded_channel::<Request<A>>();
let runtime = NonDetachingJoinHandle::from(tokio::spawn(async move {
let (queue, runner) = BackgroundJobQueue::new();
actor.init(&queue);
ConcurrentRunner {
actor,
shutdown: Some(shutdown_recv),
waiting: Vec::new(),
recv: messenger_recv,
handlers: Vec::new(),
queue,
bg_runner: runner,
}
.await
}));
Self {
shutdown: shutdown_send,
runtime,
messenger: messenger_send,
}
}
/// Message is guaranteed to be queued immediately
pub fn queue<M: Send + 'static>(
&self,
message: M,
) -> impl Future<Output = Result<A::Response, Error>>
where
A: Handler<M>,
{
if self.runtime.is_finished() {
return futures::future::Either::Left(ready(Err(Error::new(
eyre!("actor runtime has exited"),
ErrorKind::Unknown,
))));
}
let (reply_send, reply_recv) = oneshot::channel();
self.messenger
.send((Box::new(message), reply_send))
.unwrap();
futures::future::Either::Right(
reply_recv
.map_err(|_| Error::new(eyre!("actor runtime has exited"), ErrorKind::Unknown))
.and_then(|a| {
ready(
a.downcast()
.map_err(|_| {
Error::new(
eyre!("received incorrect type in response"),
ErrorKind::Incoherent,
)
})
.map(|a| *a),
)
}),
)
}
pub async fn send<M: Send + 'static>(&self, message: M) -> Result<A::Response, Error>
where
A: Handler<M>,
{
self.queue(message).await
}
pub async fn shutdown(self, strategy: PendingMessageStrategy) {
drop(self.messenger);
let timeout = match strategy {
PendingMessageStrategy::CancelAll => {
self.shutdown.send(()).unwrap();
Some(Duration::from_secs(0))
}
PendingMessageStrategy::FinishCurrentCancelPending { timeout } => {
self.shutdown.send(()).unwrap();
timeout
}
PendingMessageStrategy::FinishAll { timeout } => timeout,
};
let aborter = if let Some(timeout) = timeout {
let hdl = self.runtime.abort_handle();
async move {
tokio::time::sleep(timeout).await;
hdl.abort();
}
.boxed()
} else {
futures::future::pending().boxed()
};
tokio::select! {
_ = aborter => (),
_ = self.runtime => (),
}
}
}

View File

@@ -0,0 +1,148 @@
use std::any::{Any, TypeId};
use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Duration;
use futures::future::BoxFuture;
use futures::{Future, FutureExt};
use tokio::sync::oneshot;
#[allow(unused_imports)]
use crate::prelude::*;
use crate::util::actor::background::BackgroundJobQueue;
pub mod background;
pub mod concurrent;
pub mod simple;
pub trait Actor: Sized + Send + 'static {
#[allow(unused_variables)]
fn init(&mut self, jobs: &BackgroundJobQueue) {}
}
pub trait Handler<M: Any + Send>: Actor {
type Response: Any + Send;
/// DRAGONS: this must be correctly implemented bi-directionally in order to work as expected
fn conflicts_with(#[allow(unused_variables)] msg: &M) -> ConflictBuilder<Self> {
ConflictBuilder::everything()
}
fn handle(
&mut self,
msg: M,
jobs: &BackgroundJobQueue,
) -> impl Future<Output = Self::Response> + Send;
}
type ConflictFn<A> = dyn Fn(&dyn Message<A>) -> bool + Send + Sync;
trait Message<A>: Send + Any {
fn conflicts_with(&self) -> Arc<ConflictFn<A>>;
fn handle_with<'a>(
self: Box<Self>,
actor: &'a mut A,
jobs: &'a BackgroundJobQueue,
) -> BoxFuture<'a, Box<dyn Any + Send>>;
}
impl<M: Send + Any, A: Actor> Message<A> for M
where
A: Handler<M>,
{
fn conflicts_with(&self) -> Arc<ConflictFn<A>> {
A::conflicts_with(self).build()
}
fn handle_with<'a>(
self: Box<Self>,
actor: &'a mut A,
jobs: &'a BackgroundJobQueue,
) -> BoxFuture<'a, Box<dyn Any + Send>> {
async move { Box::new(actor.handle(*self, jobs).await) as Box<dyn Any + Send> }.boxed()
}
}
impl<A: Actor> dyn Message<A> {
#[inline]
pub fn is<M: Message<A>>(&self) -> bool {
let t = TypeId::of::<M>();
let concrete = self.type_id();
t == concrete
}
#[inline]
pub unsafe fn downcast_ref_unchecked<M: Message<A>>(&self) -> &M {
debug_assert!(self.is::<M>());
unsafe { &*(self as *const dyn Message<A> as *const M) }
}
#[inline]
fn downcast_ref<M: Message<A>>(&self) -> Option<&M> {
if self.is::<M>() {
unsafe { Some(self.downcast_ref_unchecked()) }
} else {
None
}
}
}
type Request<A> = (Box<dyn Message<A>>, oneshot::Sender<Box<dyn Any + Send>>);
pub enum PendingMessageStrategy {
CancelAll,
FinishCurrentCancelPending { timeout: Option<Duration> },
FinishAll { timeout: Option<Duration> },
}
pub struct ConflictBuilder<A> {
base: bool,
except: BTreeMap<TypeId, Option<Box<dyn Fn(&dyn Message<A>) -> bool + Send + Sync>>>,
}
impl<A: Actor> ConflictBuilder<A> {
pub const fn everything() -> Self {
Self {
base: true,
except: BTreeMap::new(),
}
}
pub const fn nothing() -> Self {
Self {
base: false,
except: BTreeMap::new(),
}
}
pub fn except<M: Any + Send>(mut self) -> Self
where
A: Handler<M>,
{
self.except.insert(TypeId::of::<M>(), None);
self
}
pub fn except_if<M: Any + Send, F: Fn(&M) -> bool + Send + Sync + 'static>(
mut self,
f: F,
) -> Self
where
A: Handler<M>,
{
self.except.insert(
TypeId::of::<M>(),
Some(Box::new(move |m| {
if let Some(m) = m.downcast_ref() {
f(m)
} else {
false
}
})),
);
self
}
fn build(self) -> Arc<ConflictFn<A>> {
Arc::new(move |m| {
self.base
^ if let Some(entry) = self.except.get(&m.type_id()) {
if let Some(f) = entry {
f(m)
} else {
true
}
} else {
false
}
})
}
}

View File

@@ -1,85 +1,14 @@
use std::any::Any;
use std::future::ready;
use std::time::Duration;
use futures::future::BoxFuture;
use futures::future::ready;
use futures::{Future, FutureExt, TryFutureExt};
use helpers::NonDetachingJoinHandle;
use tokio::sync::oneshot::error::TryRecvError;
use tokio::sync::{mpsc, oneshot};
use crate::prelude::*;
use crate::util::Never;
pub trait Actor: Send + 'static {
#[allow(unused_variables)]
fn init(&mut self, jobs: &mut BackgroundJobs) {}
}
pub trait Handler<M>: Actor {
type Response: Any + Send;
fn handle(
&mut self,
msg: M,
jobs: &mut BackgroundJobs,
) -> impl Future<Output = Self::Response> + Send;
}
#[async_trait::async_trait]
trait Message<A>: Send {
async fn handle_with(
self: Box<Self>,
actor: &mut A,
jobs: &mut BackgroundJobs,
) -> Box<dyn Any + Send>;
}
#[async_trait::async_trait]
impl<M: Send, A: Actor> Message<A> for M
where
A: Handler<M>,
{
async fn handle_with(
self: Box<Self>,
actor: &mut A,
jobs: &mut BackgroundJobs,
) -> Box<dyn Any + Send> {
Box::new(actor.handle(*self, jobs).await)
}
}
type Request<A> = (Box<dyn Message<A>>, oneshot::Sender<Box<dyn Any + Send>>);
#[derive(Default)]
pub struct BackgroundJobs {
jobs: Vec<BoxFuture<'static, ()>>,
}
impl BackgroundJobs {
pub fn add_job(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
self.jobs.push(fut.boxed());
}
}
impl Future for BackgroundJobs {
type Output = Never;
fn poll(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let complete = self
.jobs
.iter_mut()
.enumerate()
.filter_map(|(i, f)| match f.poll_unpin(cx) {
std::task::Poll::Pending => None,
std::task::Poll::Ready(_) => Some(i),
})
.collect::<Vec<_>>();
for idx in complete.into_iter().rev() {
#[allow(clippy::let_underscore_future)]
let _ = self.jobs.swap_remove(idx);
}
std::task::Poll::Pending
}
}
use crate::util::actor::background::BackgroundJobQueue;
use crate::util::actor::{Actor, Handler, PendingMessageStrategy, Request};
pub struct SimpleActor<A: Actor> {
shutdown: oneshot::Sender<()>,
@@ -91,19 +20,17 @@ impl<A: Actor> SimpleActor<A> {
let (shutdown_send, mut shutdown_recv) = oneshot::channel();
let (messenger_send, mut messenger_recv) = mpsc::unbounded_channel::<Request<A>>();
let runtime = NonDetachingJoinHandle::from(tokio::spawn(async move {
let mut bg = BackgroundJobs::default();
actor.init(&mut bg);
let (queue, mut runner) = BackgroundJobQueue::new();
actor.init(&queue);
loop {
tokio::select! {
_ = &mut bg => (),
_ = &mut runner => (),
msg = messenger_recv.recv() => match msg {
Some((msg, reply)) if shutdown_recv.try_recv() == Err(TryRecvError::Empty) => {
let mut new_bg = BackgroundJobs::default();
tokio::select! {
res = msg.handle_with(&mut actor, &mut new_bg) => { let _ = reply.send(res); },
_ = &mut bg => (),
res = msg.handle_with(&mut actor, &queue) => { let _ = reply.send(res); },
_ = &mut runner => (),
}
bg.jobs.append(&mut new_bg.jobs);
}
_ => break,
},
@@ -189,9 +116,3 @@ impl<A: Actor> SimpleActor<A> {
}
}
}
pub enum PendingMessageStrategy {
CancelAll,
FinishCurrentCancelPending { timeout: Option<Duration> },
FinishAll { timeout: Option<Duration> },
}