1use crate::{
4 backend::{RevertStateSnapshotAction, StateSnapshot},
5 state_snapshot::StateSnapshots,
6};
7use alloy_network::Network;
8use alloy_primitives::{Address, B256, U256, map::HashMap};
9use alloy_rpc_types::BlockId;
10use foundry_fork_db::{BlockchainDb, DatabaseError, SharedBackend};
11use parking_lot::Mutex;
12use revm::{
13 Database, DatabaseCommit,
14 bytecode::Bytecode,
15 database::{CacheDB, DatabaseRef},
16 state::{Account, AccountInfo},
17};
18use std::sync::Arc;
19
20#[derive(Clone, Debug)]
27pub struct ForkedDatabase<N: Network> {
28 backend: SharedBackend<N>,
32 cache_db: CacheDB<SharedBackend<N>>,
38 db: BlockchainDb,
42 state_snapshots: Arc<Mutex<StateSnapshots<ForkDbStateSnapshot<N>>>>,
44}
45
46impl<N: Network> ForkedDatabase<N> {
47 pub fn new(backend: SharedBackend<N>, db: BlockchainDb) -> Self {
49 Self {
50 cache_db: CacheDB::new(backend.clone()),
51 backend,
52 db,
53 state_snapshots: Arc::new(Mutex::new(Default::default())),
54 }
55 }
56
57 pub fn database(&self) -> &CacheDB<SharedBackend<N>> {
58 &self.cache_db
59 }
60
61 pub fn database_mut(&mut self) -> &mut CacheDB<SharedBackend<N>> {
62 &mut self.cache_db
63 }
64
65 pub fn state_snapshots(&self) -> &Arc<Mutex<StateSnapshots<ForkDbStateSnapshot<N>>>> {
66 &self.state_snapshots
67 }
68
69 pub fn reset(
71 &mut self,
72 _url: Option<String>,
73 block_number: impl Into<BlockId>,
74 ) -> Result<(), String> {
75 self.backend.set_pinned_block(block_number).map_err(|err| err.to_string())?;
76
77 self.inner().db().clear();
81 self.cache_db = CacheDB::new(self.backend.clone());
83 trace!(target: "backend::forkdb", "Cleared database");
84 Ok(())
85 }
86
87 pub fn flush_cache(&self) {
89 self.db.cache().flush()
90 }
91
92 pub fn inner(&self) -> &BlockchainDb {
94 &self.db
95 }
96
97 pub fn create_state_snapshot(&self) -> ForkDbStateSnapshot<N> {
98 let db = self.db.db();
99 let state_snapshot = StateSnapshot {
100 accounts: db.accounts.read().clone(),
101 storage: db.storage.read().clone(),
102 block_hashes: db.block_hashes.read().clone(),
103 };
104 ForkDbStateSnapshot { local: self.cache_db.clone(), state_snapshot }
105 }
106
107 pub fn insert_state_snapshot(&self) -> U256 {
108 let state_snapshot = self.create_state_snapshot();
109 let mut state_snapshots = self.state_snapshots().lock();
110 let id = state_snapshots.insert(state_snapshot);
111 trace!(target: "backend::forkdb", "Created new snapshot {}", id);
112 id
113 }
114
115 pub fn revert_state_snapshot(&mut self, id: U256, action: RevertStateSnapshotAction) -> bool {
117 let state_snapshot = { self.state_snapshots().lock().remove_at(id) };
118 if let Some(state_snapshot) = state_snapshot {
119 if action.is_keep() {
120 self.state_snapshots().lock().insert_at(state_snapshot.clone(), id);
121 }
122 let ForkDbStateSnapshot {
123 local,
124 state_snapshot: StateSnapshot { accounts, storage, block_hashes },
125 } = state_snapshot;
126 let db = self.inner().db();
127 {
128 let mut accounts_lock = db.accounts.write();
129 accounts_lock.clear();
130 accounts_lock.extend(accounts);
131 }
132 {
133 let mut storage_lock = db.storage.write();
134 storage_lock.clear();
135 storage_lock.extend(storage);
136 }
137 {
138 let mut block_hashes_lock = db.block_hashes.write();
139 block_hashes_lock.clear();
140 block_hashes_lock.extend(block_hashes);
141 }
142
143 self.cache_db = local;
144
145 trace!(target: "backend::forkdb", "Reverted snapshot {}", id);
146 true
147 } else {
148 warn!(target: "backend::forkdb", "No snapshot to revert for {}", id);
149 false
150 }
151 }
152}
153
154impl<N: Network> Database for ForkedDatabase<N> {
155 type Error = DatabaseError;
156
157 fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
158 Database::basic(&mut self.cache_db, address)
162 }
163
164 fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
165 Database::code_by_hash(&mut self.cache_db, code_hash)
166 }
167
168 fn storage(&mut self, address: Address, index: U256) -> Result<U256, Self::Error> {
169 Database::storage(&mut self.cache_db, address, index)
170 }
171
172 fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
173 Database::block_hash(&mut self.cache_db, number)
174 }
175}
176
177impl<N: Network> DatabaseRef for ForkedDatabase<N> {
178 type Error = DatabaseError;
179
180 fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
181 self.cache_db.basic_ref(address)
182 }
183
184 fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
185 self.cache_db.code_by_hash_ref(code_hash)
186 }
187
188 fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
189 DatabaseRef::storage_ref(&self.cache_db, address, index)
190 }
191
192 fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
193 self.cache_db.block_hash_ref(number)
194 }
195}
196
197impl<N: Network> DatabaseCommit for ForkedDatabase<N> {
198 fn commit(&mut self, changes: HashMap<Address, Account>) {
199 self.database_mut().commit(changes)
200 }
201}
202
203#[derive(Clone, Debug)]
207pub struct ForkDbStateSnapshot<N: Network> {
208 pub local: CacheDB<SharedBackend<N>>,
209 pub state_snapshot: StateSnapshot,
210}
211
212impl<N: Network> ForkDbStateSnapshot<N> {
213 fn get_storage(&self, address: Address, index: U256) -> Option<U256> {
214 self.local
215 .cache
216 .accounts
217 .get(&address)
218 .and_then(|account| account.storage.get(&index))
219 .copied()
220 }
221}
222
223impl<N: Network> DatabaseRef for ForkDbStateSnapshot<N> {
227 type Error = DatabaseError;
228
229 fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
230 match self.local.cache.accounts.get(&address) {
231 Some(account) => Ok(Some(account.info.clone())),
232 None => {
233 let mut acc = self.state_snapshot.accounts.get(&address).cloned();
234
235 if acc.is_none() {
236 acc = self.local.basic_ref(address)?;
237 }
238 Ok(acc)
239 }
240 }
241 }
242
243 fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
244 self.local.code_by_hash_ref(code_hash)
245 }
246
247 fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
248 match self.local.cache.accounts.get(&address) {
249 Some(account) => match account.storage.get(&index) {
250 Some(entry) => Ok(*entry),
251 None => match self.get_storage(address, index) {
252 None => DatabaseRef::storage_ref(&self.local, address, index),
253 Some(storage) => Ok(storage),
254 },
255 },
256 None => match self.get_storage(address, index) {
257 None => DatabaseRef::storage_ref(&self.local, address, index),
258 Some(storage) => Ok(storage),
259 },
260 }
261 }
262
263 fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
264 match self.state_snapshot.block_hashes.get(&U256::from(number)).copied() {
265 None => self.local.block_hash_ref(number),
266 Some(block_hash) => Ok(block_hash),
267 }
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::BlockchainDbMeta;
    use foundry_common::provider::get_http_provider;

    /// Checks that `Database::basic` yields `Some` for a random address, and that account info
    /// written into the local cache afterwards is what a second lookup returns.
    #[tokio::test(flavor = "multi_thread")]
    async fn fork_db_insert_basic_default() {
        let endpoint = foundry_test_utils::rpc::next_http_rpc_endpoint();
        let provider = get_http_provider(endpoint.clone());
        let meta = BlockchainDbMeta::new(Default::default(), endpoint);

        let blockchain_db = BlockchainDb::new(meta, None);

        let backend =
            SharedBackend::spawn_backend(Arc::new(provider), blockchain_db.clone(), None).await;

        let mut fork_db = ForkedDatabase::new(backend, blockchain_db);
        let address = Address::random();

        // First lookup: nothing has been written locally for this address yet.
        let fetched = Database::basic(&mut fork_db, address).unwrap();
        assert!(fetched.is_some());
        let mut expected = fetched.unwrap();
        expected.balance = U256::from(500u64);

        // Override the account in the local cache only.
        fork_db.database_mut().insert_account_info(address, expected.clone());

        // Second lookup must return the overridden info.
        let reloaded = Database::basic(&mut fork_db, address).unwrap();
        assert!(reloaded.is_some());
        assert_eq!(reloaded.unwrap(), expected);
    }
}