Skip to main content

anvil/tasks/
mod.rs

1//! Task management support
2
3#![allow(rustdoc::private_doc_tests)]
4
5use crate::{EthApi, shutdown::Shutdown, tasks::block_listener::BlockListener};
6use alloy_consensus::BlockHeader;
7use alloy_network::{BlockResponse, Network};
8use alloy_primitives::B256;
9use alloy_provider::Provider;
10use alloy_rpc_types::anvil::Forking;
11use foundry_primitives::FoundryNetwork;
12use futures::StreamExt;
13use std::fmt;
14use tokio::{runtime::Handle, task::JoinHandle};
15
16pub mod block_listener;
17
18/// A helper struct for managing additional tokio tasks.
19#[derive(Clone)]
20pub struct TaskManager {
21    /// Tokio runtime handle that's used to spawn futures, See [tokio::runtime::Handle].
22    tokio_handle: Handle,
23    /// A receiver for the shutdown signal
24    on_shutdown: Shutdown,
25}
26
27impl TaskManager {
28    /// Creates a new instance of the task manager
29    pub fn new(tokio_handle: Handle, on_shutdown: Shutdown) -> Self {
30        Self { tokio_handle, on_shutdown }
31    }
32
33    /// Returns a receiver for the shutdown event
34    pub fn on_shutdown(&self) -> Shutdown {
35        self.on_shutdown.clone()
36    }
37
38    /// Spawns the given task.
39    pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static) -> JoinHandle<()> {
40        self.tokio_handle.spawn(task)
41    }
42
43    /// Spawns the blocking task and returns a handle to it.
44    ///
45    /// Returning the `JoinHandle` allows callers to cancel the task or await its completion.
46    pub fn spawn_blocking(
47        &self,
48        task: impl Future<Output = ()> + Send + 'static,
49    ) -> JoinHandle<()> {
50        let handle = self.tokio_handle.clone();
51        self.tokio_handle.spawn_blocking(move || {
52            handle.block_on(task);
53        })
54    }
55
56    /// Spawns a new task that listens for new blocks and resets the forked provider for every new
57    /// block
58    ///
59    /// ```
60    /// use alloy_network::Ethereum;
61    /// use alloy_provider::RootProvider;
62    /// use anvil::{NodeConfig, spawn};
63    ///
64    /// # async fn t() {
65    /// let endpoint = "http://....";
66    /// let (api, handle) = spawn(NodeConfig::default().with_eth_rpc_url(Some(endpoint))).await;
67    ///
68    /// let provider = RootProvider::connect(endpoint).await.unwrap();
69    ///
70    /// handle.task_manager().spawn_reset_on_new_polled_blocks::<Ethereum, _>(provider, api);
71    /// # }
72    /// ```
73    pub fn spawn_reset_on_new_polled_blocks<N, P>(&self, provider: P, api: EthApi<FoundryNetwork>)
74    where
75        N: Network,
76        P: Provider<N> + Clone + Unpin + 'static,
77    {
78        self.spawn_block_poll_listener(provider.clone(), move |hash| {
79            let provider = provider.clone();
80            let api = api.clone();
81            async move {
82                if let Ok(Some(block)) = provider.get_block(hash.into()).await {
83                    let _ = api
84                        .anvil_reset(Some(Forking {
85                            json_rpc_url: None,
86                            block_number: Some(block.header().number()),
87                        }))
88                        .await;
89                }
90            }
91        })
92    }
93
94    /// Spawns a new [`BlockListener`] task that listens for new blocks (poll-based) See also
95    /// [`Provider::watch_blocks`] and executes the future the `task_factory` returns for the new
96    /// block hash
97    pub fn spawn_block_poll_listener<N, P, F, Fut>(&self, provider: P, task_factory: F)
98    where
99        N: Network,
100        P: Provider<N> + 'static,
101        F: Fn(B256) -> Fut + Unpin + Send + Sync + 'static,
102        Fut: Future<Output = ()> + Send,
103    {
104        let shutdown = self.on_shutdown.clone();
105        self.spawn(async move {
106            let blocks = provider
107                .watch_blocks()
108                .await
109                .unwrap()
110                .into_stream()
111                .flat_map(futures::stream::iter);
112            BlockListener::new(shutdown, blocks, task_factory).await;
113        });
114    }
115
116    /// Spawns a new task that listens for new blocks and resets the forked provider for every new
117    /// block
118    ///
119    /// ```
120    /// use alloy_network::Ethereum;
121    /// use alloy_provider::RootProvider;
122    /// use anvil::{NodeConfig, spawn};
123    ///
124    /// # async fn t() {
125    /// let (api, handle) = spawn(NodeConfig::default().with_eth_rpc_url(Some("http://...."))).await;
126    ///
127    /// let provider = RootProvider::connect("ws://...").await.unwrap();
128    ///
129    /// handle.task_manager().spawn_reset_on_subscribed_blocks::<Ethereum, _>(provider, api);
130    ///
131    /// # }
132    /// ```
133    pub fn spawn_reset_on_subscribed_blocks<N, P>(&self, provider: P, api: EthApi<FoundryNetwork>)
134    where
135        N: Network,
136        P: Provider<N> + 'static,
137    {
138        self.spawn_block_subscription(provider, move |header: N::HeaderResponse| {
139            let api = api.clone();
140            async move {
141                let _ = api
142                    .anvil_reset(Some(Forking {
143                        json_rpc_url: None,
144                        block_number: Some(header.number()),
145                    }))
146                    .await;
147            }
148        })
149    }
150
151    /// Spawns a new [`BlockListener`] task that listens for new blocks (via subscription) See also
152    /// [`Provider::subscribe_blocks()`] and executes the future the `task_factory` returns for the
153    /// new block hash
154    pub fn spawn_block_subscription<N, P, F, Fut>(&self, provider: P, task_factory: F)
155    where
156        N: Network,
157        P: Provider<N> + 'static,
158        F: Fn(N::HeaderResponse) -> Fut + Unpin + Send + Sync + 'static,
159        Fut: Future<Output = ()> + Send,
160    {
161        let shutdown = self.on_shutdown.clone();
162        self.spawn(async move {
163            let blocks = provider.subscribe_blocks().await.unwrap().into_stream();
164            BlockListener::new(shutdown, blocks, task_factory).await;
165        });
166    }
167}
168
169impl fmt::Debug for TaskManager {
170    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171        f.debug_struct("TaskManager").finish_non_exhaustive()
172    }
173}