1use crate::{
4 eth::backend::db::{
5 Db, MaybeForkedDatabase, MaybeFullDatabase, SerializableAccountRecord, SerializableBlock,
6 SerializableHistoricalStates, SerializableState, SerializableTransaction, StateDb,
7 },
8 mem::state::state_root,
9};
10use alloy_primitives::{map::HashMap, Address, B256, U256};
11use alloy_rpc_types::BlockId;
12use foundry_evm::backend::{BlockchainDb, DatabaseResult, StateSnapshot};
13use revm::{
14 context::BlockEnv,
15 database::{DatabaseRef, DbAccount},
16 state::AccountInfo,
17};
18
19pub use foundry_evm::backend::MemDb;
21use foundry_evm::backend::RevertStateSnapshotAction;
22
23impl Db for MemDb {
24 fn insert_account(&mut self, address: Address, account: AccountInfo) {
25 self.inner.insert_account_info(address, account)
26 }
27
28 fn set_storage_at(&mut self, address: Address, slot: B256, val: B256) -> DatabaseResult<()> {
29 self.inner.insert_account_storage(address, slot.into(), val.into())
30 }
31
32 fn insert_block_hash(&mut self, number: U256, hash: B256) {
33 self.inner.cache.block_hashes.insert(number, hash);
34 }
35
36 fn dump_state(
37 &self,
38 at: BlockEnv,
39 best_number: u64,
40 blocks: Vec<SerializableBlock>,
41 transactions: Vec<SerializableTransaction>,
42 historical_states: Option<SerializableHistoricalStates>,
43 ) -> DatabaseResult<Option<SerializableState>> {
44 let accounts = self
45 .inner
46 .cache
47 .accounts
48 .clone()
49 .into_iter()
50 .map(|(k, v)| -> DatabaseResult<_> {
51 let code = if let Some(code) = v.info.code {
52 code
53 } else {
54 self.inner.code_by_hash_ref(v.info.code_hash)?
55 };
56 Ok((
57 k,
58 SerializableAccountRecord {
59 nonce: v.info.nonce,
60 balance: v.info.balance,
61 code: code.original_bytes(),
62 storage: v.storage.into_iter().map(|(k, v)| (k.into(), v.into())).collect(),
63 },
64 ))
65 })
66 .collect::<Result<_, _>>()?;
67
68 Ok(Some(SerializableState {
69 block: Some(at),
70 accounts,
71 best_block_number: Some(best_number),
72 blocks,
73 transactions,
74 historical_states,
75 }))
76 }
77
78 fn snapshot_state(&mut self) -> U256 {
80 let id = self.state_snapshots.insert(self.inner.clone());
81 trace!(target: "backend::memdb", "Created new state snapshot {}", id);
82 id
83 }
84
85 fn revert_state(&mut self, id: U256, action: RevertStateSnapshotAction) -> bool {
86 if let Some(state_snapshot) = self.state_snapshots.remove(id) {
87 if action.is_keep() {
88 self.state_snapshots.insert_at(state_snapshot.clone(), id);
89 }
90 self.inner = state_snapshot;
91 trace!(target: "backend::memdb", "Reverted state snapshot {}", id);
92 true
93 } else {
94 warn!(target: "backend::memdb", "No state snapshot to revert for {}", id);
95 false
96 }
97 }
98
99 fn maybe_state_root(&self) -> Option<B256> {
100 Some(state_root(&self.inner.cache.accounts))
101 }
102
103 fn current_state(&self) -> StateDb {
104 StateDb::new(Self { inner: self.inner.clone(), ..Default::default() })
105 }
106}
107
108impl MaybeFullDatabase for MemDb {
109 fn as_dyn(&self) -> &dyn DatabaseRef<Error = foundry_evm::backend::DatabaseError> {
110 self
111 }
112
113 fn maybe_as_full_db(&self) -> Option<&HashMap<Address, DbAccount>> {
114 Some(&self.inner.cache.accounts)
115 }
116
117 fn clear_into_state_snapshot(&mut self) -> StateSnapshot {
118 self.inner.clear_into_state_snapshot()
119 }
120
121 fn read_as_state_snapshot(&self) -> StateSnapshot {
122 self.inner.read_as_state_snapshot()
123 }
124
125 fn clear(&mut self) {
126 self.inner.clear();
127 }
128
129 fn init_from_state_snapshot(&mut self, snapshot: StateSnapshot) {
130 self.inner.init_from_state_snapshot(snapshot)
131 }
132}
133
134impl MaybeForkedDatabase for MemDb {
135 fn maybe_reset(&mut self, _url: Option<String>, _block_number: BlockId) -> Result<(), String> {
136 Err("not supported".to_string())
137 }
138
139 fn maybe_flush_cache(&self) -> Result<(), String> {
140 Err("not supported".to_string())
141 }
142
143 fn maybe_inner(&self) -> Result<&BlockchainDb, String> {
144 Err("not supported".to_string())
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{address, Bytes};
    use revm::{bytecode::Bytecode, primitives::KECCAK_EMPTY};
    use std::collections::BTreeMap;

    /// Account fixture shared by both tests: balance 123456, nonce 1234, the given code
    /// attached inline and `KECCAK_EMPTY` as the declared code hash.
    fn seeded_account(code: &Bytecode) -> AccountInfo {
        AccountInfo {
            balance: U256::from(123456),
            code_hash: KECCAK_EMPTY,
            code: Some(code.clone()),
            nonce: 1234,
        }
    }

    /// Dumps a database holding one account and loads the dump into a fresh database,
    /// then checks balance, code, nonce and storage all come back unchanged.
    #[test]
    fn test_dump_reload_cycle() {
        let owner: Address = address!("0xf39fd6e51aad88f6f4ce6ab8827279cfffb92266");
        let code = Bytecode::new_raw(Bytes::from("fake contract code"));
        let slot = U256::from(1234567);

        // Populate the database that will be dumped.
        let mut source_db = MemDb::default();
        source_db.insert_account(owner, seeded_account(&code));
        source_db.set_storage_at(owner, slot.into(), U256::from(1).into()).unwrap();

        // No blocks, transactions or historical states are part of this dump.
        let dumped = source_db
            .dump_state(Default::default(), 0, Vec::new(), Vec::new(), Default::default())
            .unwrap()
            .unwrap();

        // Load the dump into an empty database and read everything back.
        let mut target_db = MemDb::default();
        target_db.load_state(dumped).unwrap();

        let account = target_db.basic_ref(owner).unwrap().unwrap();

        assert_eq!(account.balance, U256::from(123456));
        assert_eq!(target_db.code_by_hash_ref(account.code_hash).unwrap(), code);
        assert_eq!(account.nonce, 1234);
        assert_eq!(target_db.storage_ref(owner, slot).unwrap(), U256::from(1));
    }

    /// Loads a state into a database that already holds data and checks the merge result:
    /// the new account appears, the existing account's balance and the overlapping storage
    /// slot take the loaded values, and the untouched storage slot is preserved.
    #[test]
    fn test_load_state_merge() {
        let existing: Address = address!("0xf39fd6e51aad88f6f4ce6ab8827279cfffb92266");
        let added: Address = address!("0x70997970c51812dc3a010c7d01b50e0d17dc79c8");
        let code = Bytecode::new_raw(Bytes::from("fake contract code"));

        let untouched_slot = U256::from(1234567);
        let overwritten_slot = U256::from(1234568);

        // Pre-existing database content: one account with two storage slots.
        let mut db = MemDb::default();
        db.insert_account(existing, seeded_account(&code));
        db.set_storage_at(existing, untouched_slot.into(), U256::from(1).into()).unwrap();
        db.set_storage_at(existing, overwritten_slot.into(), U256::from(2).into()).unwrap();

        // Incoming state: a brand-new account plus an update for the existing one.
        let mut incoming = SerializableState::default();

        incoming.accounts.insert(
            added,
            SerializableAccountRecord {
                balance: Default::default(),
                code: Default::default(),
                nonce: 1,
                storage: Default::default(),
            },
        );

        let mut incoming_storage = BTreeMap::default();
        incoming_storage.insert(overwritten_slot.into(), U256::from(5).into());

        incoming.accounts.insert(
            existing,
            SerializableAccountRecord {
                balance: U256::from(100100),
                code: code.bytes()[..code.len()].to_vec().into(),
                nonce: 100,
                storage: incoming_storage,
            },
        );

        db.load_state(incoming).unwrap();

        let existing_account = db.basic_ref(existing).unwrap().unwrap();
        let added_account = db.basic_ref(added).unwrap().unwrap();

        assert_eq!(added_account.nonce, 1);

        assert_eq!(existing_account.balance, U256::from(100100));
        assert_eq!(db.code_by_hash_ref(existing_account.code_hash).unwrap(), code);
        // The nonce stays at the pre-existing 1234 rather than the loaded 100
        // (presumably `load_state` keeps the higher nonce; confirm against `load_state`).
        assert_eq!(existing_account.nonce, 1234);
        assert_eq!(db.storage_ref(existing, untouched_slot).unwrap(), U256::from(1));
        assert_eq!(db.storage_ref(existing, overwritten_slot).unwrap(), U256::from(5));
    }
}