1#![allow(rustdoc::private_doc_tests)]
4
5use crate::{EthApi, shutdown::Shutdown, tasks::block_listener::BlockListener};
6use alloy_consensus::BlockHeader;
7use alloy_network::{BlockResponse, Network};
8use alloy_primitives::B256;
9use alloy_provider::Provider;
10use alloy_rpc_types::anvil::Forking;
11use foundry_primitives::FoundryNetwork;
12use futures::StreamExt;
13use std::fmt;
14use tokio::{runtime::Handle, task::JoinHandle};
15
16pub mod block_listener;
17
18#[derive(Clone)]
20pub struct TaskManager {
21 tokio_handle: Handle,
23 on_shutdown: Shutdown,
25}
26
27impl TaskManager {
28 pub fn new(tokio_handle: Handle, on_shutdown: Shutdown) -> Self {
30 Self { tokio_handle, on_shutdown }
31 }
32
33 pub fn on_shutdown(&self) -> Shutdown {
35 self.on_shutdown.clone()
36 }
37
38 pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static) -> JoinHandle<()> {
40 self.tokio_handle.spawn(task)
41 }
42
43 pub fn spawn_blocking(
47 &self,
48 task: impl Future<Output = ()> + Send + 'static,
49 ) -> JoinHandle<()> {
50 let handle = self.tokio_handle.clone();
51 self.tokio_handle.spawn_blocking(move || {
52 handle.block_on(task);
53 })
54 }
55
56 pub fn spawn_reset_on_new_polled_blocks<N, P>(&self, provider: P, api: EthApi<FoundryNetwork>)
74 where
75 N: Network,
76 P: Provider<N> + Clone + Unpin + 'static,
77 {
78 self.spawn_block_poll_listener(provider.clone(), move |hash| {
79 let provider = provider.clone();
80 let api = api.clone();
81 async move {
82 if let Ok(Some(block)) = provider.get_block(hash.into()).await {
83 let _ = api
84 .anvil_reset(Some(Forking {
85 json_rpc_url: None,
86 block_number: Some(block.header().number()),
87 }))
88 .await;
89 }
90 }
91 })
92 }
93
94 pub fn spawn_block_poll_listener<N, P, F, Fut>(&self, provider: P, task_factory: F)
98 where
99 N: Network,
100 P: Provider<N> + 'static,
101 F: Fn(B256) -> Fut + Unpin + Send + Sync + 'static,
102 Fut: Future<Output = ()> + Send,
103 {
104 let shutdown = self.on_shutdown.clone();
105 self.spawn(async move {
106 let blocks = provider
107 .watch_blocks()
108 .await
109 .unwrap()
110 .into_stream()
111 .flat_map(futures::stream::iter);
112 BlockListener::new(shutdown, blocks, task_factory).await;
113 });
114 }
115
116 pub fn spawn_reset_on_subscribed_blocks<N, P>(&self, provider: P, api: EthApi<FoundryNetwork>)
134 where
135 N: Network,
136 P: Provider<N> + 'static,
137 {
138 self.spawn_block_subscription(provider, move |header: N::HeaderResponse| {
139 let api = api.clone();
140 async move {
141 let _ = api
142 .anvil_reset(Some(Forking {
143 json_rpc_url: None,
144 block_number: Some(header.number()),
145 }))
146 .await;
147 }
148 })
149 }
150
151 pub fn spawn_block_subscription<N, P, F, Fut>(&self, provider: P, task_factory: F)
155 where
156 N: Network,
157 P: Provider<N> + 'static,
158 F: Fn(N::HeaderResponse) -> Fut + Unpin + Send + Sync + 'static,
159 Fut: Future<Output = ()> + Send,
160 {
161 let shutdown = self.on_shutdown.clone();
162 self.spawn(async move {
163 let blocks = provider.subscribe_blocks().await.unwrap().into_stream();
164 BlockListener::new(shutdown, blocks, task_factory).await;
165 });
166 }
167}
168
169impl fmt::Debug for TaskManager {
170 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171 f.debug_struct("TaskManager").finish_non_exhaustive()
172 }
173}