anvil/tasks/mod.rs

1//! Task management support
2
3#![allow(rustdoc::private_doc_tests)]
4
5use crate::{shutdown::Shutdown, tasks::block_listener::BlockListener, EthApi};
6use alloy_network::{AnyHeader, AnyNetwork};
7use alloy_primitives::B256;
8use alloy_provider::Provider;
9use alloy_rpc_types::anvil::Forking;
10use futures::StreamExt;
11use std::{fmt, future::Future};
12use tokio::{runtime::Handle, task::JoinHandle};
13
14pub mod block_listener;
15
16/// A helper struct for managing additional tokio tasks.
17#[derive(Clone)]
18pub struct TaskManager {
19    /// Tokio runtime handle that's used to spawn futures, See [tokio::runtime::Handle].
20    tokio_handle: Handle,
21    /// A receiver for the shutdown signal
22    on_shutdown: Shutdown,
23}
24
25impl TaskManager {
26    /// Creates a new instance of the task manager
27    pub fn new(tokio_handle: Handle, on_shutdown: Shutdown) -> Self {
28        Self { tokio_handle, on_shutdown }
29    }
30
31    /// Returns a receiver for the shutdown event
32    pub fn on_shutdown(&self) -> Shutdown {
33        self.on_shutdown.clone()
34    }
35
36    /// Spawns the given task.
37    pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static) -> JoinHandle<()> {
38        self.tokio_handle.spawn(task)
39    }
40
41    /// Spawns the blocking task.
42    pub fn spawn_blocking(&self, task: impl Future<Output = ()> + Send + 'static) {
43        let handle = self.tokio_handle.clone();
44        self.tokio_handle.spawn_blocking(move || {
45            handle.block_on(task);
46        });
47    }
48
49    /// Spawns a new task that listens for new blocks and resets the forked provider for every new
50    /// block
51    ///
52    /// ```
53    /// use alloy_network::Ethereum;
54    /// use alloy_provider::RootProvider;
55    /// use anvil::{spawn, NodeConfig};
56    ///
57    /// # async fn t() {
58    /// let endpoint = "http://....";
59    /// let (api, handle) = spawn(NodeConfig::default().with_eth_rpc_url(Some(endpoint))).await;
60    ///
61    /// let provider = RootProvider::connect_builtin(endpoint).await.unwrap();
62    ///
63    /// handle.task_manager().spawn_reset_on_new_polled_blocks(provider, api);
64    /// # }
65    /// ```
66    pub fn spawn_reset_on_new_polled_blocks<P>(&self, provider: P, api: EthApi)
67    where
68        P: Provider<AnyNetwork> + Clone + Unpin + 'static,
69    {
70        self.spawn_block_poll_listener(provider.clone(), move |hash| {
71            let provider = provider.clone();
72            let api = api.clone();
73            async move {
74                if let Ok(Some(block)) = provider.get_block(hash.into()).await {
75                    let _ = api
76                        .anvil_reset(Some(Forking {
77                            json_rpc_url: None,
78                            block_number: Some(block.header.number),
79                        }))
80                        .await;
81                }
82            }
83        })
84    }
85
86    /// Spawns a new [`BlockListener`] task that listens for new blocks (poll-based) See also
87    /// [`Provider::watch_blocks`] and executes the future the `task_factory` returns for the new
88    /// block hash
89    pub fn spawn_block_poll_listener<P, F, Fut>(&self, provider: P, task_factory: F)
90    where
91        P: Provider<AnyNetwork> + 'static,
92        F: Fn(B256) -> Fut + Unpin + Send + Sync + 'static,
93        Fut: Future<Output = ()> + Send,
94    {
95        let shutdown = self.on_shutdown.clone();
96        self.spawn(async move {
97            let blocks = provider
98                .watch_blocks()
99                .await
100                .unwrap()
101                .into_stream()
102                .flat_map(futures::stream::iter);
103            BlockListener::new(shutdown, blocks, task_factory).await;
104        });
105    }
106
107    /// Spawns a new task that listens for new blocks and resets the forked provider for every new
108    /// block
109    ///
110    /// ```
111    /// use alloy_network::Ethereum;
112    /// use alloy_provider::RootProvider;
113    /// use anvil::{spawn, NodeConfig};
114    ///
115    /// # async fn t() {
116    /// let (api, handle) = spawn(NodeConfig::default().with_eth_rpc_url(Some("http://...."))).await;
117    ///
118    /// let provider = RootProvider::connect_builtin("ws://...").await.unwrap();
119    ///
120    /// handle.task_manager().spawn_reset_on_subscribed_blocks(provider, api);
121    ///
122    /// # }
123    /// ```
124    pub fn spawn_reset_on_subscribed_blocks<P>(&self, provider: P, api: EthApi)
125    where
126        P: Provider<AnyNetwork> + 'static,
127    {
128        self.spawn_block_subscription(provider, move |header| {
129            let api = api.clone();
130            async move {
131                let _ = api
132                    .anvil_reset(Some(Forking {
133                        json_rpc_url: None,
134                        block_number: Some(header.number),
135                    }))
136                    .await;
137            }
138        })
139    }
140
141    /// Spawns a new [`BlockListener`] task that listens for new blocks (via subscription) See also
142    /// [`Provider::subscribe_blocks()`] and executes the future the `task_factory` returns for the
143    /// new block hash
144    pub fn spawn_block_subscription<P, F, Fut>(&self, provider: P, task_factory: F)
145    where
146        P: Provider<AnyNetwork> + 'static,
147        F: Fn(alloy_rpc_types::Header<AnyHeader>) -> Fut + Unpin + Send + Sync + 'static,
148        Fut: Future<Output = ()> + Send,
149    {
150        let shutdown = self.on_shutdown.clone();
151        self.spawn(async move {
152            let blocks = provider.subscribe_blocks().await.unwrap().into_stream();
153            BlockListener::new(shutdown, blocks, task_factory).await;
154        });
155    }
156}
157
158impl fmt::Debug for TaskManager {
159    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
160        f.debug_struct("TaskManager").finish_non_exhaustive()
161    }
162}