Skip to main content

anvil/
pubsub.rs

1use crate::{
2    StorageInfo,
3    eth::{backend::notifications::NewBlockNotifications, error::to_rpc_result},
4};
5use alloy_consensus::{BlockHeader, TxReceipt};
6use alloy_network::{AnyRpcTransaction, Network};
7use alloy_primitives::{B256, TxHash};
8use alloy_rpc_types::{FilteredParams, Log, Transaction, pubsub::SubscriptionResult};
9use anvil_core::eth::{block::Block, subscription::SubscriptionId};
10use anvil_rpc::{request::Version, response::ResponseResult};
11use futures::{Stream, StreamExt, channel::mpsc::Receiver, ready};
12use serde::Serialize;
13use std::{
14    collections::VecDeque,
15    pin::Pin,
16    task::{Context, Poll},
17};
18use tokio::sync::mpsc::UnboundedReceiver;
19
/// Listens for new blocks and matching logs emitted in that block.
///
/// Each log in a new block that matches `filter` is emitted as its own `eth_subscription`
/// notification.
pub struct LogsSubscription<N: Network> {
    /// Stream of new-block notifications; each item carries the `hash` of a new block.
    pub blocks: NewBlockNotifications,
    /// Storage used to load the full block and its receipts by block hash.
    pub storage: StorageInfo<N>,
    /// Filter every block's logs are matched against (see `filter_logs`).
    pub filter: FilteredParams,
    /// Matching logs that have been collected but not yet emitted; drained front-first.
    pub queued: VecDeque<Log>,
    /// Subscription id attached to every emitted notification.
    pub id: SubscriptionId,
}
28
29impl<N: Network> std::fmt::Debug for LogsSubscription<N> {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("LogsSubscription")
32            .field("filter", &self.filter)
33            .field("id", &self.id)
34            .finish_non_exhaustive()
35    }
36}
37
38impl<N: Network> LogsSubscription<N>
39where
40    N::ReceiptEnvelope: TxReceipt<Log = alloy_primitives::Log> + Clone,
41{
42    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<EthSubscriptionResponse>> {
43        loop {
44            if let Some(log) = self.queued.pop_front() {
45                let params = EthSubscriptionParams {
46                    subscription: self.id.clone(),
47                    result: to_rpc_result(log),
48                };
49                return Poll::Ready(Some(EthSubscriptionResponse::new(params)));
50            }
51
52            if let Some(block) = ready!(self.blocks.poll_next_unpin(cx)) {
53                let b = self.storage.block(block.hash);
54                let receipts = self.storage.receipts(block.hash);
55                if let (Some(receipts), Some(block)) = (receipts, b) {
56                    let logs = filter_logs(block, receipts, &self.filter);
57                    if logs.is_empty() {
58                        // this ensures we poll the receiver until it is pending, in which case the
59                        // underlying `UnboundedReceiver` will register the new waker, see
60                        // [`futures::channel::mpsc::UnboundedReceiver::poll_next()`]
61                        continue;
62                    }
63                    self.queued.extend(logs)
64                }
65            } else {
66                return Poll::Ready(None);
67            }
68
69            if self.queued.is_empty() {
70                return Poll::Pending;
71            }
72        }
73    }
74}
75
/// A JSON-RPC notification for an active subscription, serialized with the fields `jsonrpc`,
/// `method` and `params` (in that order).
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
pub struct EthSubscriptionResponse {
    /// JSON-RPC protocol version; `EthSubscriptionResponse::new` sets this to `Version::V2`.
    jsonrpc: Version,
    /// Notification method name; `EthSubscriptionResponse::new` sets this to `"eth_subscription"`.
    method: &'static str,
    /// Subscription id plus the notification payload.
    params: EthSubscriptionParams,
}
82
83impl EthSubscriptionResponse {
84    pub fn new(params: EthSubscriptionParams) -> Self {
85        Self { jsonrpc: Version::V2, method: "eth_subscription", params }
86    }
87}
88
/// Represents the `params` field of an `eth_subscription` event
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
pub struct EthSubscriptionParams {
    /// Id of the subscription this event belongs to.
    subscription: SubscriptionId,
    /// Event payload. `#[serde(flatten)]` makes its serialized fields appear directly next to
    /// `subscription` rather than nested under a `result` key.
    // NOTE(review): presumably this serializes as a `result` or `error` member — confirm against
    // `ResponseResult`'s `Serialize` impl.
    #[serde(flatten)]
    result: ResponseResult,
}
96
/// Represents an ethereum Websocket subscription
pub enum EthSubscription<N: Network> {
    /// Emits every log in new blocks that matches the subscription's filter, one notification
    /// per log.
    // NOTE(review): presumably boxed to keep the enum small (clippy `large_enum_variant`).
    Logs(Box<LogsSubscription<N>>),
    /// Emits, for each new-block notification, the block loaded via `StorageInfo::eth_block`.
    /// Blocks that cannot be loaded are skipped.
    Header(NewBlockNotifications, StorageInfo<N>, SubscriptionId),
    /// Emits the hash of each transaction received on the channel.
    PendingTransactions(Receiver<TxHash>, SubscriptionId),
    /// Emits each full transaction received on the channel.
    FullPendingTransactions(UnboundedReceiver<AnyRpcTransaction>, SubscriptionId),
}
104
105impl<N: Network> std::fmt::Debug for EthSubscription<N> {
106    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107        match self {
108            Self::Logs(_) => f.debug_tuple("Logs").finish(),
109            Self::Header(..) => f.debug_tuple("Header").finish(),
110            Self::PendingTransactions(..) => f.debug_tuple("PendingTransactions").finish(),
111            Self::FullPendingTransactions(..) => f.debug_tuple("FullPendingTransactions").finish(),
112        }
113    }
114}
115
116impl<N: Network> EthSubscription<N>
117where
118    N::ReceiptEnvelope: TxReceipt<Log = alloy_primitives::Log> + Clone,
119{
120    fn poll_response(&mut self, cx: &mut Context<'_>) -> Poll<Option<EthSubscriptionResponse>> {
121        match self {
122            Self::Logs(listener) => listener.poll(cx),
123            Self::Header(blocks, storage, id) => {
124                // this loop ensures we poll the receiver until it is pending, in which case the
125                // underlying `UnboundedReceiver` will register the new waker, see
126                // [`futures::channel::mpsc::UnboundedReceiver::poll_next()`]
127                loop {
128                    if let Some(block) = ready!(blocks.poll_next_unpin(cx)) {
129                        if let Some(block) = storage.eth_block(block.hash) {
130                            let params = EthSubscriptionParams {
131                                subscription: id.clone(),
132                                result: to_rpc_result(block),
133                            };
134                            return Poll::Ready(Some(EthSubscriptionResponse::new(params)));
135                        }
136                    } else {
137                        return Poll::Ready(None);
138                    }
139                }
140            }
141            Self::PendingTransactions(tx, id) => {
142                let res = ready!(tx.poll_next_unpin(cx))
143                    .map(SubscriptionResult::<Transaction>::TransactionHash)
144                    .map(to_rpc_result)
145                    .map(|result| {
146                        let params = EthSubscriptionParams { subscription: id.clone(), result };
147                        EthSubscriptionResponse::new(params)
148                    });
149                Poll::Ready(res)
150            }
151            Self::FullPendingTransactions(tx, id) => {
152                let res = ready!(tx.poll_recv(cx)).map(to_rpc_result).map(|result| {
153                    let params = EthSubscriptionParams { subscription: id.clone(), result };
154                    EthSubscriptionResponse::new(params)
155                });
156                Poll::Ready(res)
157            }
158        }
159    }
160}
161
162impl<N: Network> Stream for EthSubscription<N>
163where
164    N::ReceiptEnvelope: TxReceipt<Log = alloy_primitives::Log> + Clone,
165{
166    type Item = serde_json::Value;
167
168    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
169        let pin = self.get_mut();
170        match ready!(pin.poll_response(cx)) {
171            None => Poll::Ready(None),
172            Some(res) => Poll::Ready(Some(serde_json::to_value(res).expect("can't fail;"))),
173        }
174    }
175}
176
177/// Returns all the logs that match the given filter
178pub fn filter_logs<R>(block: Block, receipts: Vec<R>, filter: &FilteredParams) -> Vec<Log>
179where
180    R: TxReceipt<Log = alloy_primitives::Log>,
181{
182    /// Determines whether to add this log
183    fn add_log(
184        block_hash: B256,
185        l: &alloy_primitives::Log,
186        block: &Block,
187        params: &FilteredParams,
188    ) -> bool {
189        if params.filter.is_some() {
190            let block_number = block.header.number();
191            if !params.filter_block_range(block_number)
192                || !params.filter_block_hash(block_hash)
193                || !params.filter_address(&l.address)
194                || !params.filter_topics(l.topics())
195            {
196                return false;
197            }
198        }
199        true
200    }
201
202    let block_hash = block.header.hash_slow();
203    let mut logs = vec![];
204    let mut log_index: u32 = 0;
205    for (receipt_index, receipt) in receipts.into_iter().enumerate() {
206        let transaction_hash = block.body.transactions[receipt_index].hash();
207        for log in receipt.logs() {
208            if add_log(block_hash, log, &block, filter) {
209                logs.push(Log {
210                    inner: log.clone(),
211                    block_hash: Some(block_hash),
212                    block_number: Some(block.header.number()),
213                    transaction_hash: Some(transaction_hash),
214                    transaction_index: Some(receipt_index as u64),
215                    log_index: Some(log_index as u64),
216                    removed: false,
217                    block_timestamp: Some(block.header.timestamp()),
218                });
219            }
220            log_index += 1;
221        }
222    }
223    logs
224}