use crate::{
    backend::{RevertStateSnapshotAction, StateSnapshot},
    state_snapshot::StateSnapshots,
};
use alloy_network::Network;
use alloy_primitives::{Address, B256, U256};
use alloy_rpc_types::BlockId;
use foundry_fork_db::{BlockchainDb, DatabaseError, ForkBlockEnv, SharedBackend};
use parking_lot::Mutex;
use revm::{
    Database, DatabaseCommit,
    bytecode::Bytecode,
    context::BlockEnv,
    database::{AccountState, CacheDB, DatabaseRef},
    primitives::AddressMap,
    state::{Account, AccountInfo},
};
use std::sync::Arc;
21
22#[derive(Clone, Debug)]
29pub struct ForkedDatabase<N: Network, B: ForkBlockEnv = BlockEnv> {
30 backend: SharedBackend<N, B>,
34 cache_db: CacheDB<SharedBackend<N, B>>,
40 db: BlockchainDb<B>,
44 state_snapshots: Arc<Mutex<StateSnapshots<ForkDbStateSnapshot<N, B>>>>,
46}
47
48impl<N: Network, B: ForkBlockEnv> ForkedDatabase<N, B> {
49 pub fn new(backend: SharedBackend<N, B>, db: BlockchainDb<B>) -> Self {
51 Self {
52 cache_db: CacheDB::new(backend.clone()),
53 backend,
54 db,
55 state_snapshots: Arc::new(Mutex::new(Default::default())),
56 }
57 }
58
59 pub const fn database(&self) -> &CacheDB<SharedBackend<N, B>> {
60 &self.cache_db
61 }
62
63 pub const fn database_mut(&mut self) -> &mut CacheDB<SharedBackend<N, B>> {
64 &mut self.cache_db
65 }
66
67 pub const fn state_snapshots(&self) -> &Arc<Mutex<StateSnapshots<ForkDbStateSnapshot<N, B>>>> {
68 &self.state_snapshots
69 }
70
71 pub fn reset(
73 &mut self,
74 _urls: Vec<String>,
75 block_number: impl Into<BlockId>,
76 ) -> Result<(), String> {
77 self.backend.set_pinned_block(block_number).map_err(|err| err.to_string())?;
78
79 self.inner().db().clear();
83 self.cache_db = CacheDB::new(self.backend.clone());
85 trace!(target: "backend::forkdb", "Cleared database");
86 Ok(())
87 }
88
89 pub fn flush_cache(&self) {
91 self.db.cache().flush()
92 }
93
94 pub const fn inner(&self) -> &BlockchainDb<B> {
96 &self.db
97 }
98
99 pub fn create_state_snapshot(&self) -> ForkDbStateSnapshot<N, B> {
100 let db = self.db.db();
101 let state_snapshot = StateSnapshot {
102 accounts: db.accounts.read().clone(),
103 storage: db.storage.read().clone(),
104 block_hashes: db.block_hashes.read().clone(),
105 };
106 ForkDbStateSnapshot { local: self.cache_db.clone(), state_snapshot }
107 }
108
109 pub fn insert_state_snapshot(&self) -> U256 {
110 let state_snapshot = self.create_state_snapshot();
111 let mut state_snapshots = self.state_snapshots().lock();
112 let id = state_snapshots.insert(state_snapshot);
113 trace!(target: "backend::forkdb", "Created new snapshot {}", id);
114 id
115 }
116
117 pub fn revert_state_snapshot(&mut self, id: U256, action: RevertStateSnapshotAction) -> bool {
119 let state_snapshot = { self.state_snapshots().lock().remove_at(id) };
120 if let Some(state_snapshot) = state_snapshot {
121 if action.is_keep() {
122 self.state_snapshots().lock().insert_at(state_snapshot.clone(), id);
123 }
124 let ForkDbStateSnapshot {
125 local,
126 state_snapshot: StateSnapshot { accounts, storage, block_hashes },
127 } = state_snapshot;
128 let db = self.inner().db();
129 {
130 let mut accounts_lock = db.accounts.write();
131 accounts_lock.clear();
132 accounts_lock.extend(accounts);
133 }
134 {
135 let mut storage_lock = db.storage.write();
136 storage_lock.clear();
137 storage_lock.extend(storage);
138 }
139 {
140 let mut block_hashes_lock = db.block_hashes.write();
141 block_hashes_lock.clear();
142 block_hashes_lock.extend(block_hashes);
143 }
144
145 self.cache_db = local;
146
147 trace!(target: "backend::forkdb", "Reverted snapshot {}", id);
148 true
149 } else {
150 warn!(target: "backend::forkdb", "No snapshot to revert for {}", id);
151 false
152 }
153 }
154}
155
156impl<N: Network, B: ForkBlockEnv> Database for ForkedDatabase<N, B> {
157 type Error = DatabaseError;
158
159 fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
160 Database::basic(&mut self.cache_db, address)
164 }
165
166 fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
167 Database::code_by_hash(&mut self.cache_db, code_hash)
168 }
169
170 fn storage(&mut self, address: Address, index: U256) -> Result<U256, Self::Error> {
171 Database::storage(&mut self.cache_db, address, index)
172 }
173
174 fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
175 Database::block_hash(&mut self.cache_db, number)
176 }
177}
178
179impl<N: Network, B: ForkBlockEnv> DatabaseRef for ForkedDatabase<N, B> {
180 type Error = DatabaseError;
181
182 fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
183 self.cache_db.basic_ref(address)
184 }
185
186 fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
187 self.cache_db.code_by_hash_ref(code_hash)
188 }
189
190 fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
191 DatabaseRef::storage_ref(&self.cache_db, address, index)
192 }
193
194 fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
195 self.cache_db.block_hash_ref(number)
196 }
197}
198
199impl<N: Network, B: ForkBlockEnv> DatabaseCommit for ForkedDatabase<N, B> {
200 fn commit(&mut self, changes: AddressMap<Account>) {
201 self.database_mut().commit(changes)
202 }
203}
204
205#[derive(Clone, Debug)]
209pub struct ForkDbStateSnapshot<N: Network, B: ForkBlockEnv = BlockEnv> {
210 pub local: CacheDB<SharedBackend<N, B>>,
211 pub state_snapshot: StateSnapshot,
212}
213
214impl<N: Network, B: ForkBlockEnv> ForkDbStateSnapshot<N, B> {
215 fn storage_from_snapshot_or_backend(
217 &self,
218 address: Address,
219 index: U256,
220 ) -> Result<U256, DatabaseError> {
221 if let Some(val) = self.state_snapshot.storage.get(&address).and_then(|s| s.get(&index)) {
223 return Ok(*val);
224 }
225 DatabaseRef::storage_ref(&self.local, address, index)
227 }
228}
229
230impl<N: Network, B: ForkBlockEnv> DatabaseRef for ForkDbStateSnapshot<N, B> {
234 type Error = DatabaseError;
235
236 fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
237 match self.local.cache.accounts.get(&address) {
238 Some(account) => Ok(Some(account.info.clone())),
239 None => {
240 let mut acc = self.state_snapshot.accounts.get(&address).cloned();
241
242 if acc.is_none() {
243 acc = self.local.basic_ref(address)?;
244 }
245 Ok(acc)
246 }
247 }
248 }
249
250 fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
251 self.local.code_by_hash_ref(code_hash)
252 }
253
254 fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
255 match self.local.cache.accounts.get(&address) {
256 Some(account) => match account.storage.get(&index) {
257 Some(entry) => Ok(*entry),
258 None => self.storage_from_snapshot_or_backend(address, index),
259 },
260 None => self.storage_from_snapshot_or_backend(address, index),
261 }
262 }
263
264 fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
265 match self.state_snapshot.block_hashes.get(&U256::from(number)).copied() {
266 None => self.local.block_hash_ref(number),
267 Some(block_hash) => Ok(block_hash),
268 }
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::BlockchainDbMeta;
    use foundry_common::provider::get_http_provider;

    /// A random address still resolves to `Some` through the fork, and account info inserted
    /// into the local overlay is what the next `Database::basic` call returns.
    #[tokio::test(flavor = "multi_thread")]
    async fn fork_db_insert_basic_default() {
        let endpoint = foundry_test_utils::rpc::next_http_rpc_endpoint();
        let provider = get_http_provider(endpoint.clone());
        let chain_db =
            BlockchainDb::new(BlockchainDbMeta::new(BlockEnv::default(), endpoint), None);
        let backend =
            SharedBackend::spawn_backend(Arc::new(provider), chain_db.clone(), None).await;
        let mut fork_db = ForkedDatabase::new(backend, chain_db);

        let address = Address::random();
        let fetched = Database::basic(&mut fork_db, address).unwrap();
        assert!(fetched.is_some());

        let mut info = fetched.unwrap();
        info.balance = U256::from(500u64);
        fork_db.database_mut().insert_account_info(address, info.clone());

        let reloaded = Database::basic(&mut fork_db, address).unwrap();
        assert_eq!(reloaded, Some(info));
    }

    /// A slot present only in the captured storage is returned by `storage_ref` when the
    /// overlay is empty.
    #[tokio::test(flavor = "multi_thread")]
    async fn fork_db_state_snapshot_reads_storage_from_snapshot() {
        let endpoint = foundry_test_utils::rpc::next_http_rpc_endpoint();
        let provider = get_http_provider(endpoint.clone());
        let chain_db =
            BlockchainDb::new(BlockchainDbMeta::new(BlockEnv::default(), endpoint), None);
        let backend = SharedBackend::spawn_backend(Arc::new(provider), chain_db, None).await;

        let (address, slot, expected) =
            (Address::random(), U256::from(42u64), U256::from(0xdeadbeefu64));

        let mut state_snapshot = StateSnapshot::default();
        state_snapshot.storage.entry(address).or_default().insert(slot, expected);
        let snapshot = ForkDbStateSnapshot { local: CacheDB::new(backend), state_snapshot };

        assert_eq!(DatabaseRef::storage_ref(&snapshot, address, slot).unwrap(), expected);
    }
}