use crate::{
    backend::{RevertStateSnapshotAction, StateSnapshot},
    state_snapshot::StateSnapshots,
};
use alloy_primitives::{map::HashMap, Address, B256, U256};
use alloy_rpc_types::BlockId;
use foundry_fork_db::{BlockchainDb, DatabaseError, SharedBackend};
use parking_lot::Mutex;
use revm::{
    bytecode::Bytecode,
    database::{AccountState, CacheDB, DatabaseRef},
    state::{Account, AccountInfo},
    Database, DatabaseCommit,
};
use std::sync::Arc;
18
/// A database for a forked chain.
///
/// Reads and commits go through a local [`CacheDB`] layered on a [`SharedBackend`]. The
/// [`BlockchainDb`] is kept alongside so its in-memory contents can be captured into, and
/// restored from, state snapshots, and cleared on reset.
#[derive(Clone, Debug)]
pub struct ForkedDatabase {
    /// Backend handle; a clone of it is the underlying database of `cache_db`.
    backend: SharedBackend,
    /// Local cache layer on top of `backend`; all `Database`/`DatabaseRef`/commit calls use it.
    cache_db: CacheDB<SharedBackend>,
    /// In-memory fork database whose accounts, storage and block hashes are snapshotted,
    /// restored and cleared here (presumably the same db `backend` was spawned with — confirm).
    db: BlockchainDb,
    /// Saved state snapshots; the `Arc` makes this registry shared by all clones of this value.
    state_snapshots: Arc<Mutex<StateSnapshots<ForkDbStateSnapshot>>>,
}
44
45impl ForkedDatabase {
46 pub fn new(backend: SharedBackend, db: BlockchainDb) -> Self {
48 Self {
49 cache_db: CacheDB::new(backend.clone()),
50 backend,
51 db,
52 state_snapshots: Arc::new(Mutex::new(Default::default())),
53 }
54 }
55
56 pub fn database(&self) -> &CacheDB<SharedBackend> {
57 &self.cache_db
58 }
59
60 pub fn database_mut(&mut self) -> &mut CacheDB<SharedBackend> {
61 &mut self.cache_db
62 }
63
64 pub fn state_snapshots(&self) -> &Arc<Mutex<StateSnapshots<ForkDbStateSnapshot>>> {
65 &self.state_snapshots
66 }
67
68 pub fn reset(
70 &mut self,
71 _url: Option<String>,
72 block_number: impl Into<BlockId>,
73 ) -> Result<(), String> {
74 self.backend.set_pinned_block(block_number).map_err(|err| err.to_string())?;
75
76 self.inner().db().clear();
80 self.cache_db = CacheDB::new(self.backend.clone());
82 trace!(target: "backend::forkdb", "Cleared database");
83 Ok(())
84 }
85
86 pub fn flush_cache(&self) {
88 self.db.cache().flush()
89 }
90
91 pub fn inner(&self) -> &BlockchainDb {
93 &self.db
94 }
95
96 pub fn create_state_snapshot(&self) -> ForkDbStateSnapshot {
97 let db = self.db.db();
98 let state_snapshot = StateSnapshot {
99 accounts: db.accounts.read().clone(),
100 storage: db.storage.read().clone(),
101 block_hashes: db.block_hashes.read().clone(),
102 };
103 ForkDbStateSnapshot { local: self.cache_db.clone(), state_snapshot }
104 }
105
106 pub fn insert_state_snapshot(&self) -> U256 {
107 let state_snapshot = self.create_state_snapshot();
108 let mut state_snapshots = self.state_snapshots().lock();
109 let id = state_snapshots.insert(state_snapshot);
110 trace!(target: "backend::forkdb", "Created new snapshot {}", id);
111 id
112 }
113
114 pub fn revert_state_snapshot(&mut self, id: U256, action: RevertStateSnapshotAction) -> bool {
116 let state_snapshot = { self.state_snapshots().lock().remove_at(id) };
117 if let Some(state_snapshot) = state_snapshot {
118 if action.is_keep() {
119 self.state_snapshots().lock().insert_at(state_snapshot.clone(), id);
120 }
121 let ForkDbStateSnapshot {
122 local,
123 state_snapshot: StateSnapshot { accounts, storage, block_hashes },
124 } = state_snapshot;
125 let db = self.inner().db();
126 {
127 let mut accounts_lock = db.accounts.write();
128 accounts_lock.clear();
129 accounts_lock.extend(accounts);
130 }
131 {
132 let mut storage_lock = db.storage.write();
133 storage_lock.clear();
134 storage_lock.extend(storage);
135 }
136 {
137 let mut block_hashes_lock = db.block_hashes.write();
138 block_hashes_lock.clear();
139 block_hashes_lock.extend(block_hashes);
140 }
141
142 self.cache_db = local;
143
144 trace!(target: "backend::forkdb", "Reverted snapshot {}", id);
145 true
146 } else {
147 warn!(target: "backend::forkdb", "No snapshot to revert for {}", id);
148 false
149 }
150 }
151}
152
153impl Database for ForkedDatabase {
154 type Error = DatabaseError;
155
156 fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
157 Database::basic(&mut self.cache_db, address)
161 }
162
163 fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
164 Database::code_by_hash(&mut self.cache_db, code_hash)
165 }
166
167 fn storage(&mut self, address: Address, index: U256) -> Result<U256, Self::Error> {
168 Database::storage(&mut self.cache_db, address, index)
169 }
170
171 fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
172 Database::block_hash(&mut self.cache_db, number)
173 }
174}
175
176impl DatabaseRef for ForkedDatabase {
177 type Error = DatabaseError;
178
179 fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
180 self.cache_db.basic_ref(address)
181 }
182
183 fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
184 self.cache_db.code_by_hash_ref(code_hash)
185 }
186
187 fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
188 DatabaseRef::storage_ref(&self.cache_db, address, index)
189 }
190
191 fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
192 self.cache_db.block_hash_ref(number)
193 }
194}
195
196impl DatabaseCommit for ForkedDatabase {
197 fn commit(&mut self, changes: HashMap<Address, Account>) {
198 self.database_mut().commit(changes)
199 }
200}
201
/// A saved state of a [`ForkedDatabase`], as produced by `create_state_snapshot`.
#[derive(Clone, Debug)]
pub struct ForkDbStateSnapshot {
    /// Clone of the database's local cache at the time the snapshot was taken.
    pub local: CacheDB<SharedBackend>,
    /// Copies of the `BlockchainDb` accounts, storage and block hashes at snapshot time.
    pub state_snapshot: StateSnapshot,
}
210
211impl ForkDbStateSnapshot {
212 fn get_storage(&self, address: Address, index: U256) -> Option<U256> {
213 self.local
214 .cache
215 .accounts
216 .get(&address)
217 .and_then(|account| account.storage.get(&index))
218 .copied()
219 }
220}
221
222impl DatabaseRef for ForkDbStateSnapshot {
226 type Error = DatabaseError;
227
228 fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
229 match self.local.cache.accounts.get(&address) {
230 Some(account) => Ok(Some(account.info.clone())),
231 None => {
232 let mut acc = self.state_snapshot.accounts.get(&address).cloned();
233
234 if acc.is_none() {
235 acc = self.local.basic_ref(address)?;
236 }
237 Ok(acc)
238 }
239 }
240 }
241
242 fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
243 self.local.code_by_hash_ref(code_hash)
244 }
245
246 fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
247 match self.local.cache.accounts.get(&address) {
248 Some(account) => match account.storage.get(&index) {
249 Some(entry) => Ok(*entry),
250 None => match self.get_storage(address, index) {
251 None => DatabaseRef::storage_ref(&self.local, address, index),
252 Some(storage) => Ok(storage),
253 },
254 },
255 None => match self.get_storage(address, index) {
256 None => DatabaseRef::storage_ref(&self.local, address, index),
257 Some(storage) => Ok(storage),
258 },
259 }
260 }
261
262 fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
263 match self.state_snapshot.block_hashes.get(&U256::from(number)).copied() {
264 None => self.local.block_hash_ref(number),
265 Some(block_hash) => Ok(block_hash),
266 }
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::BlockchainDbMeta;
    use foundry_common::provider::get_http_provider;
    use std::collections::BTreeSet;

    /// `Database::basic` on a `ForkedDatabase` returns `Some` for a random address.
    /// Account info inserted into the local cache afterwards is what subsequent reads return.
    #[tokio::test(flavor = "multi_thread")]
    async fn fork_db_insert_basic_default() {
        let endpoint = foundry_test_utils::rpc::next_http_rpc_endpoint();
        let provider = get_http_provider(endpoint.clone());
        let meta = BlockchainDbMeta {
            block_env: Default::default(),
            hosts: BTreeSet::from([endpoint]),
        };
        let blockchain_db = BlockchainDb::new(meta, None);

        let backend =
            SharedBackend::spawn_backend(Arc::new(provider), blockchain_db.clone(), None).await;
        let mut fork_db = ForkedDatabase::new(backend, blockchain_db);

        let target = Address::random();
        let fetched = Database::basic(&mut fork_db, target).unwrap();
        assert!(fetched.is_some());
        let mut expected = fetched.unwrap();
        expected.balance = U256::from(500u64);

        fork_db.database_mut().insert_account_info(target, expected.clone());

        let reloaded = Database::basic(&mut fork_db, target).unwrap();
        assert!(reloaded.is_some());
        assert_eq!(reloaded.unwrap(), expected);
    }
}