1#![allow(rustdoc::private_doc_tests)]
4
5use crate::{EthApi, shutdown::Shutdown, tasks::block_listener::BlockListener};
6use alloy_network::{AnyHeader, AnyNetwork};
7use alloy_primitives::B256;
8use alloy_provider::Provider;
9use alloy_rpc_types::anvil::Forking;
10use futures::StreamExt;
11use std::fmt;
12use tokio::{runtime::Handle, task::JoinHandle};
13
14pub mod block_listener;
15
16#[derive(Clone)]
18pub struct TaskManager {
19 tokio_handle: Handle,
21 on_shutdown: Shutdown,
23}
24
25impl TaskManager {
26 pub fn new(tokio_handle: Handle, on_shutdown: Shutdown) -> Self {
28 Self { tokio_handle, on_shutdown }
29 }
30
31 pub fn on_shutdown(&self) -> Shutdown {
33 self.on_shutdown.clone()
34 }
35
36 pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static) -> JoinHandle<()> {
38 self.tokio_handle.spawn(task)
39 }
40
41 pub fn spawn_blocking(
45 &self,
46 task: impl Future<Output = ()> + Send + 'static,
47 ) -> JoinHandle<()> {
48 let handle = self.tokio_handle.clone();
49 self.tokio_handle.spawn_blocking(move || {
50 handle.block_on(task);
51 })
52 }
53
54 pub fn spawn_reset_on_new_polled_blocks<P>(&self, provider: P, api: EthApi)
72 where
73 P: Provider<AnyNetwork> + Clone + Unpin + 'static,
74 {
75 self.spawn_block_poll_listener(provider.clone(), move |hash| {
76 let provider = provider.clone();
77 let api = api.clone();
78 async move {
79 if let Ok(Some(block)) = provider.get_block(hash.into()).await {
80 let _ = api
81 .anvil_reset(Some(Forking {
82 json_rpc_url: None,
83 block_number: Some(block.header.number),
84 }))
85 .await;
86 }
87 }
88 })
89 }
90
91 pub fn spawn_block_poll_listener<P, F, Fut>(&self, provider: P, task_factory: F)
95 where
96 P: Provider<AnyNetwork> + 'static,
97 F: Fn(B256) -> Fut + Unpin + Send + Sync + 'static,
98 Fut: Future<Output = ()> + Send,
99 {
100 let shutdown = self.on_shutdown.clone();
101 self.spawn(async move {
102 let blocks = provider
103 .watch_blocks()
104 .await
105 .unwrap()
106 .into_stream()
107 .flat_map(futures::stream::iter);
108 BlockListener::new(shutdown, blocks, task_factory).await;
109 });
110 }
111
112 pub fn spawn_reset_on_subscribed_blocks<P>(&self, provider: P, api: EthApi)
130 where
131 P: Provider<AnyNetwork> + 'static,
132 {
133 self.spawn_block_subscription(provider, move |header| {
134 let api = api.clone();
135 async move {
136 let _ = api
137 .anvil_reset(Some(Forking {
138 json_rpc_url: None,
139 block_number: Some(header.number),
140 }))
141 .await;
142 }
143 })
144 }
145
146 pub fn spawn_block_subscription<P, F, Fut>(&self, provider: P, task_factory: F)
150 where
151 P: Provider<AnyNetwork> + 'static,
152 F: Fn(alloy_rpc_types::Header<AnyHeader>) -> Fut + Unpin + Send + Sync + 'static,
153 Fut: Future<Output = ()> + Send,
154 {
155 let shutdown = self.on_shutdown.clone();
156 self.spawn(async move {
157 let blocks = provider.subscribe_blocks().await.unwrap().into_stream();
158 BlockListener::new(shutdown, blocks, task_factory).await;
159 });
160 }
161}
162
163impl fmt::Debug for TaskManager {
164 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165 f.debug_struct("TaskManager").finish_non_exhaustive()
166 }
167}