anvil/tasks/
mod.rs

1//! Task management support
2
3#![allow(rustdoc::private_doc_tests)]
4
5use crate::{EthApi, shutdown::Shutdown, tasks::block_listener::BlockListener};
6use alloy_network::{AnyHeader, AnyNetwork};
7use alloy_primitives::B256;
8use alloy_provider::Provider;
9use alloy_rpc_types::anvil::Forking;
10use futures::StreamExt;
11use std::fmt;
12use tokio::{runtime::Handle, task::JoinHandle};
13
14pub mod block_listener;
15
16/// A helper struct for managing additional tokio tasks.
17#[derive(Clone)]
18pub struct TaskManager {
19    /// Tokio runtime handle that's used to spawn futures, See [tokio::runtime::Handle].
20    tokio_handle: Handle,
21    /// A receiver for the shutdown signal
22    on_shutdown: Shutdown,
23}
24
25impl TaskManager {
26    /// Creates a new instance of the task manager
27    pub fn new(tokio_handle: Handle, on_shutdown: Shutdown) -> Self {
28        Self { tokio_handle, on_shutdown }
29    }
30
31    /// Returns a receiver for the shutdown event
32    pub fn on_shutdown(&self) -> Shutdown {
33        self.on_shutdown.clone()
34    }
35
36    /// Spawns the given task.
37    pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static) -> JoinHandle<()> {
38        self.tokio_handle.spawn(task)
39    }
40
41    /// Spawns the blocking task and returns a handle to it.
42    ///
43    /// Returning the `JoinHandle` allows callers to cancel the task or await its completion.
44    pub fn spawn_blocking(
45        &self,
46        task: impl Future<Output = ()> + Send + 'static,
47    ) -> JoinHandle<()> {
48        let handle = self.tokio_handle.clone();
49        self.tokio_handle.spawn_blocking(move || {
50            handle.block_on(task);
51        })
52    }
53
54    /// Spawns a new task that listens for new blocks and resets the forked provider for every new
55    /// block
56    ///
57    /// ```
58    /// use alloy_network::Ethereum;
59    /// use alloy_provider::RootProvider;
60    /// use anvil::{NodeConfig, spawn};
61    ///
62    /// # async fn t() {
63    /// let endpoint = "http://....";
64    /// let (api, handle) = spawn(NodeConfig::default().with_eth_rpc_url(Some(endpoint))).await;
65    ///
66    /// let provider = RootProvider::connect_builtin(endpoint).await.unwrap();
67    ///
68    /// handle.task_manager().spawn_reset_on_new_polled_blocks(provider, api);
69    /// # }
70    /// ```
71    pub fn spawn_reset_on_new_polled_blocks<P>(&self, provider: P, api: EthApi)
72    where
73        P: Provider<AnyNetwork> + Clone + Unpin + 'static,
74    {
75        self.spawn_block_poll_listener(provider.clone(), move |hash| {
76            let provider = provider.clone();
77            let api = api.clone();
78            async move {
79                if let Ok(Some(block)) = provider.get_block(hash.into()).await {
80                    let _ = api
81                        .anvil_reset(Some(Forking {
82                            json_rpc_url: None,
83                            block_number: Some(block.header.number),
84                        }))
85                        .await;
86                }
87            }
88        })
89    }
90
91    /// Spawns a new [`BlockListener`] task that listens for new blocks (poll-based) See also
92    /// [`Provider::watch_blocks`] and executes the future the `task_factory` returns for the new
93    /// block hash
94    pub fn spawn_block_poll_listener<P, F, Fut>(&self, provider: P, task_factory: F)
95    where
96        P: Provider<AnyNetwork> + 'static,
97        F: Fn(B256) -> Fut + Unpin + Send + Sync + 'static,
98        Fut: Future<Output = ()> + Send,
99    {
100        let shutdown = self.on_shutdown.clone();
101        self.spawn(async move {
102            let blocks = provider
103                .watch_blocks()
104                .await
105                .unwrap()
106                .into_stream()
107                .flat_map(futures::stream::iter);
108            BlockListener::new(shutdown, blocks, task_factory).await;
109        });
110    }
111
112    /// Spawns a new task that listens for new blocks and resets the forked provider for every new
113    /// block
114    ///
115    /// ```
116    /// use alloy_network::Ethereum;
117    /// use alloy_provider::RootProvider;
118    /// use anvil::{NodeConfig, spawn};
119    ///
120    /// # async fn t() {
121    /// let (api, handle) = spawn(NodeConfig::default().with_eth_rpc_url(Some("http://...."))).await;
122    ///
123    /// let provider = RootProvider::connect_builtin("ws://...").await.unwrap();
124    ///
125    /// handle.task_manager().spawn_reset_on_subscribed_blocks(provider, api);
126    ///
127    /// # }
128    /// ```
129    pub fn spawn_reset_on_subscribed_blocks<P>(&self, provider: P, api: EthApi)
130    where
131        P: Provider<AnyNetwork> + 'static,
132    {
133        self.spawn_block_subscription(provider, move |header| {
134            let api = api.clone();
135            async move {
136                let _ = api
137                    .anvil_reset(Some(Forking {
138                        json_rpc_url: None,
139                        block_number: Some(header.number),
140                    }))
141                    .await;
142            }
143        })
144    }
145
146    /// Spawns a new [`BlockListener`] task that listens for new blocks (via subscription) See also
147    /// [`Provider::subscribe_blocks()`] and executes the future the `task_factory` returns for the
148    /// new block hash
149    pub fn spawn_block_subscription<P, F, Fut>(&self, provider: P, task_factory: F)
150    where
151        P: Provider<AnyNetwork> + 'static,
152        F: Fn(alloy_rpc_types::Header<AnyHeader>) -> Fut + Unpin + Send + Sync + 'static,
153        Fut: Future<Output = ()> + Send,
154    {
155        let shutdown = self.on_shutdown.clone();
156        self.spawn(async move {
157            let blocks = provider.subscribe_blocks().await.unwrap().into_stream();
158            BlockListener::new(shutdown, blocks, task_factory).await;
159        });
160    }
161}
162
163impl fmt::Debug for TaskManager {
164    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165        f.debug_struct("TaskManager").finish_non_exhaustive()
166    }
167}