1use crate::{
4 backend::{RevertStateSnapshotAction, StateSnapshot},
5 state_snapshot::StateSnapshots,
6};
7use alloy_primitives::{map::HashMap, Address, B256, U256};
8use alloy_rpc_types::BlockId;
9use foundry_fork_db::{BlockchainDb, DatabaseError, SharedBackend};
10use parking_lot::Mutex;
11use revm::{
12 db::{CacheDB, DatabaseRef},
13 primitives::{Account, AccountInfo, Bytecode},
14 Database, DatabaseCommit,
15};
16use std::sync::Arc;
17
18#[derive(Clone, Debug)]
25pub struct ForkedDatabase {
26 backend: SharedBackend,
30 cache_db: CacheDB<SharedBackend>,
36 db: BlockchainDb,
40 state_snapshots: Arc<Mutex<StateSnapshots<ForkDbStateSnapshot>>>,
42}
43
44impl ForkedDatabase {
45 pub fn new(backend: SharedBackend, db: BlockchainDb) -> Self {
47 Self {
48 cache_db: CacheDB::new(backend.clone()),
49 backend,
50 db,
51 state_snapshots: Arc::new(Mutex::new(Default::default())),
52 }
53 }
54
55 pub fn database(&self) -> &CacheDB<SharedBackend> {
56 &self.cache_db
57 }
58
59 pub fn database_mut(&mut self) -> &mut CacheDB<SharedBackend> {
60 &mut self.cache_db
61 }
62
63 pub fn state_snapshots(&self) -> &Arc<Mutex<StateSnapshots<ForkDbStateSnapshot>>> {
64 &self.state_snapshots
65 }
66
67 pub fn reset(
69 &mut self,
70 _url: Option<String>,
71 block_number: impl Into<BlockId>,
72 ) -> Result<(), String> {
73 self.backend.set_pinned_block(block_number).map_err(|err| err.to_string())?;
74
75 self.inner().db().clear();
79 self.cache_db = CacheDB::new(self.backend.clone());
81 trace!(target: "backend::forkdb", "Cleared database");
82 Ok(())
83 }
84
85 pub fn flush_cache(&self) {
87 self.db.cache().flush()
88 }
89
90 pub fn inner(&self) -> &BlockchainDb {
92 &self.db
93 }
94
95 pub fn create_state_snapshot(&self) -> ForkDbStateSnapshot {
96 let db = self.db.db();
97 let state_snapshot = StateSnapshot {
98 accounts: db.accounts.read().clone(),
99 storage: db.storage.read().clone(),
100 block_hashes: db.block_hashes.read().clone(),
101 };
102 ForkDbStateSnapshot { local: self.cache_db.clone(), state_snapshot }
103 }
104
105 pub fn insert_state_snapshot(&self) -> U256 {
106 let state_snapshot = self.create_state_snapshot();
107 let mut state_snapshots = self.state_snapshots().lock();
108 let id = state_snapshots.insert(state_snapshot);
109 trace!(target: "backend::forkdb", "Created new snapshot {}", id);
110 id
111 }
112
113 pub fn revert_state_snapshot(&mut self, id: U256, action: RevertStateSnapshotAction) -> bool {
115 let state_snapshot = { self.state_snapshots().lock().remove_at(id) };
116 if let Some(state_snapshot) = state_snapshot {
117 if action.is_keep() {
118 self.state_snapshots().lock().insert_at(state_snapshot.clone(), id);
119 }
120 let ForkDbStateSnapshot {
121 local,
122 state_snapshot: StateSnapshot { accounts, storage, block_hashes },
123 } = state_snapshot;
124 let db = self.inner().db();
125 {
126 let mut accounts_lock = db.accounts.write();
127 accounts_lock.clear();
128 accounts_lock.extend(accounts);
129 }
130 {
131 let mut storage_lock = db.storage.write();
132 storage_lock.clear();
133 storage_lock.extend(storage);
134 }
135 {
136 let mut block_hashes_lock = db.block_hashes.write();
137 block_hashes_lock.clear();
138 block_hashes_lock.extend(block_hashes);
139 }
140
141 self.cache_db = local;
142
143 trace!(target: "backend::forkdb", "Reverted snapshot {}", id);
144 true
145 } else {
146 warn!(target: "backend::forkdb", "No snapshot to revert for {}", id);
147 false
148 }
149 }
150}
151
152impl Database for ForkedDatabase {
153 type Error = DatabaseError;
154
155 fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
156 Database::basic(&mut self.cache_db, address)
160 }
161
162 fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
163 Database::code_by_hash(&mut self.cache_db, code_hash)
164 }
165
166 fn storage(&mut self, address: Address, index: U256) -> Result<U256, Self::Error> {
167 Database::storage(&mut self.cache_db, address, index)
168 }
169
170 fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
171 Database::block_hash(&mut self.cache_db, number)
172 }
173}
174
175impl DatabaseRef for ForkedDatabase {
176 type Error = DatabaseError;
177
178 fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
179 self.cache_db.basic_ref(address)
180 }
181
182 fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
183 self.cache_db.code_by_hash_ref(code_hash)
184 }
185
186 fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
187 DatabaseRef::storage_ref(&self.cache_db, address, index)
188 }
189
190 fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
191 self.cache_db.block_hash_ref(number)
192 }
193}
194
195impl DatabaseCommit for ForkedDatabase {
196 fn commit(&mut self, changes: HashMap<Address, Account>) {
197 self.database_mut().commit(changes)
198 }
199}
200
201#[derive(Clone, Debug)]
205pub struct ForkDbStateSnapshot {
206 pub local: CacheDB<SharedBackend>,
207 pub state_snapshot: StateSnapshot,
208}
209
210impl ForkDbStateSnapshot {
211 fn get_storage(&self, address: Address, index: U256) -> Option<U256> {
212 self.local.accounts.get(&address).and_then(|account| account.storage.get(&index)).copied()
213 }
214}
215
216impl DatabaseRef for ForkDbStateSnapshot {
220 type Error = DatabaseError;
221
222 fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
223 match self.local.accounts.get(&address) {
224 Some(account) => Ok(Some(account.info.clone())),
225 None => {
226 let mut acc = self.state_snapshot.accounts.get(&address).cloned();
227
228 if acc.is_none() {
229 acc = self.local.basic_ref(address)?;
230 }
231 Ok(acc)
232 }
233 }
234 }
235
236 fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
237 self.local.code_by_hash_ref(code_hash)
238 }
239
240 fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
241 match self.local.accounts.get(&address) {
242 Some(account) => match account.storage.get(&index) {
243 Some(entry) => Ok(*entry),
244 None => match self.get_storage(address, index) {
245 None => DatabaseRef::storage_ref(&self.local, address, index),
246 Some(storage) => Ok(storage),
247 },
248 },
249 None => match self.get_storage(address, index) {
250 None => DatabaseRef::storage_ref(&self.local, address, index),
251 Some(storage) => Ok(storage),
252 },
253 }
254 }
255
256 fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
257 match self.state_snapshot.block_hashes.get(&U256::from(number)).copied() {
258 None => self.local.block_hash_ref(number),
259 Some(block_hash) => Ok(block_hash),
260 }
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::BlockchainDbMeta;
    use foundry_common::provider::get_http_provider;
    use std::collections::BTreeSet;

    /// Checks two things:
    /// - `Database::basic` on a fresh fork yields an account even for a random address;
    /// - an account info inserted into the local layer is what later reads return.
    #[tokio::test(flavor = "multi_thread")]
    async fn fork_db_insert_basic_default() {
        let endpoint = foundry_test_utils::rpc::next_http_rpc_endpoint();
        let provider = get_http_provider(endpoint.clone());
        let meta = BlockchainDbMeta {
            cfg_env: Default::default(),
            block_env: Default::default(),
            hosts: BTreeSet::from([endpoint]),
        };
        let blockchain_db = BlockchainDb::new(meta, None);

        let backend =
            SharedBackend::spawn_backend(Arc::new(provider), blockchain_db.clone(), None).await;

        let mut fork_db = ForkedDatabase::new(backend, blockchain_db);
        let address = Address::random();

        let fetched = Database::basic(&mut fork_db, address).unwrap();
        assert!(fetched.is_some());
        let mut expected = fetched.unwrap();
        expected.balance = U256::from(500u64);

        fork_db.database_mut().insert_account_info(address, expected.clone());

        let reloaded = Database::basic(&mut fork_db, address).unwrap();
        assert!(reloaded.is_some());
        assert_eq!(reloaded.unwrap(), expected);
    }
}