use crate::{
backend::{RevertStateSnapshotAction, StateSnapshot},
state_snapshot::StateSnapshots,
};
use alloy_primitives::{map::HashMap, Address, B256, U256};
use alloy_rpc_types::BlockId;
use foundry_fork_db::{BlockchainDb, DatabaseError, SharedBackend};
use parking_lot::Mutex;
use revm::{
db::{CacheDB, DatabaseRef},
primitives::{Account, AccountInfo, Bytecode},
Database, DatabaseCommit,
};
use std::sync::Arc;
/// A database that layers a [`CacheDB`] over a [`SharedBackend`] and supports state snapshots.
///
/// NOTE(review): presumably `backend` retrieves missing state from a remote endpoint and `db`
/// holds what has already been retrieved — confirm against `foundry_fork_db`.
#[derive(Clone, Debug)]
pub struct ForkedDatabase {
/// Backend handle. The fork block is pinned through it in `reset`, and clones of it back
/// `cache_db`.
backend: SharedBackend,
/// Cache layer on top of `backend`. All `Database`, `DatabaseRef` and `DatabaseCommit` calls
/// made on this type are forwarded to it.
cache_db: CacheDB<SharedBackend>,
/// Underlying store whose accounts, storage and block hashes are copied into state snapshots
/// and overwritten again when a snapshot is reverted.
db: BlockchainDb,
/// Registry of state snapshots behind a mutex. Held in an `Arc`, so clones of this struct
/// share one registry.
state_snapshots: Arc<Mutex<StateSnapshots<ForkDbStateSnapshot>>>,
}
impl ForkedDatabase {
    /// Builds a forked database from a backend handle and its underlying store.
    ///
    /// The cache layer is created over a clone of `backend`; the snapshot registry starts out as
    /// its `Default` value.
    pub fn new(backend: SharedBackend, db: BlockchainDb) -> Self {
        let cache_db = CacheDB::new(backend.clone());
        Self { backend, cache_db, db, state_snapshots: Arc::new(Mutex::new(Default::default())) }
    }

    /// Shared access to the cache layer.
    pub fn database(&self) -> &CacheDB<SharedBackend> {
        &self.cache_db
    }

    /// Exclusive access to the cache layer.
    pub fn database_mut(&mut self) -> &mut CacheDB<SharedBackend> {
        &mut self.cache_db
    }

    /// The shared registry of state snapshots.
    pub fn state_snapshots(&self) -> &Arc<Mutex<StateSnapshots<ForkDbStateSnapshot>>> {
        &self.state_snapshots
    }

    /// Re-pins the backend to `block_number` and discards all state held locally.
    ///
    /// `_url` is accepted but currently unused.
    ///
    /// # Errors
    ///
    /// Returns the stringified error if pinning the block fails; in that case neither the
    /// underlying store nor the cache layer has been touched.
    pub fn reset(
        &mut self,
        _url: Option<String>,
        block_number: impl Into<BlockId>,
    ) -> Result<(), String> {
        if let Err(err) = self.backend.set_pinned_block(block_number) {
            return Err(err.to_string());
        }

        // Drop everything held by the underlying store ...
        self.db.db().clear();
        // ... and start over with an empty cache layer on the same backend.
        self.cache_db = CacheDB::new(self.backend.clone());
        trace!(target: "backend::forkdb", "Cleared database");
        Ok(())
    }

    /// Calls `flush` on the cache of the underlying store.
    pub fn flush_cache(&self) {
        self.db.cache().flush()
    }

    /// The underlying [`BlockchainDb`].
    pub fn inner(&self) -> &BlockchainDb {
        &self.db
    }

    /// Captures the current state: a copy of the underlying store's accounts, storage and block
    /// hashes, together with a clone of the cache layer.
    pub fn create_state_snapshot(&self) -> ForkDbStateSnapshot {
        let store = self.db.db();
        let captured = StateSnapshot {
            accounts: store.accounts.read().clone(),
            storage: store.storage.read().clone(),
            block_hashes: store.block_hashes.read().clone(),
        };
        ForkDbStateSnapshot { local: self.cache_db.clone(), state_snapshot: captured }
    }

    /// Captures the current state, stores it in the snapshot registry and returns its id.
    pub fn insert_state_snapshot(&self) -> U256 {
        let captured = self.create_state_snapshot();
        let mut registry = self.state_snapshots.lock();
        let id = registry.insert(captured);
        trace!(target: "backend::forkdb", "Created new snapshot {}", id);
        id
    }

    /// Restores the state captured under `id`.
    ///
    /// The snapshot is removed from the registry; when `action.is_keep()` holds, a clone is put
    /// back under the same id. The contents of the underlying store and the cache layer are then
    /// replaced with the captured ones.
    ///
    /// Returns `true` if the snapshot existed and was applied, `false` (after logging a warning)
    /// otherwise.
    pub fn revert_state_snapshot(&mut self, id: U256, action: RevertStateSnapshotAction) -> bool {
        // The registry lock is released at the end of this statement.
        let removed = self.state_snapshots.lock().remove_at(id);
        let captured = match removed {
            Some(captured) => captured,
            None => {
                warn!(target: "backend::forkdb", "No snapshot to revert for {}", id);
                return false;
            }
        };

        if action.is_keep() {
            self.state_snapshots.lock().insert_at(captured.clone(), id);
        }

        let ForkDbStateSnapshot { local, state_snapshot } = captured;
        let StateSnapshot { accounts, storage, block_hashes } = state_snapshot;

        // Overwrite each map of the underlying store, holding one write lock at a time.
        let store = self.db.db();
        {
            let mut guard = store.accounts.write();
            guard.clear();
            guard.extend(accounts);
        }
        {
            let mut guard = store.storage.write();
            guard.clear();
            guard.extend(storage);
        }
        {
            let mut guard = store.block_hashes.write();
            guard.clear();
            guard.extend(block_hashes);
        }

        self.cache_db = local;
        trace!(target: "backend::forkdb", "Reverted snapshot {}", id);
        true
    }
}
// Mutable database access: every method forwards to the cache layer.
impl Database for ForkedDatabase {
    type Error = DatabaseError;

    /// Account info for `address`, as answered by the cache layer.
    fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
        Database::basic(self.database_mut(), address)
    }

    /// Bytecode for `code_hash`, as answered by the cache layer.
    fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
        Database::code_by_hash(self.database_mut(), code_hash)
    }

    /// Storage value of `address` at `index`, as answered by the cache layer.
    fn storage(&mut self, address: Address, index: U256) -> Result<U256, Self::Error> {
        Database::storage(self.database_mut(), address, index)
    }

    /// Hash of block `number`, as answered by the cache layer.
    fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
        Database::block_hash(self.database_mut(), number)
    }
}
// Read-only database access: every method forwards to the cache layer.
impl DatabaseRef for ForkedDatabase {
    type Error = DatabaseError;

    /// Account info for `address`, as answered by the cache layer.
    fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
        DatabaseRef::basic_ref(self.database(), address)
    }

    /// Bytecode for `code_hash`, as answered by the cache layer.
    fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
        DatabaseRef::code_by_hash_ref(self.database(), code_hash)
    }

    /// Storage value of `address` at `index`, as answered by the cache layer.
    fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
        DatabaseRef::storage_ref(self.database(), address, index)
    }

    /// Hash of block `number`, as answered by the cache layer.
    fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
        DatabaseRef::block_hash_ref(self.database(), number)
    }
}
impl DatabaseCommit for ForkedDatabase {
fn commit(&mut self, changes: HashMap<Address, Account>) {
self.database_mut().commit(changes)
}
}
/// State captured from a [`ForkedDatabase`]: its cache layer plus a copy of the underlying
/// store's contents.
#[derive(Clone, Debug)]
pub struct ForkDbStateSnapshot {
/// Clone of the `CacheDB` layer taken when the snapshot was created.
pub local: CacheDB<SharedBackend>,
/// Accounts, storage and block hashes copied from the underlying store when the snapshot was
/// created.
pub state_snapshot: StateSnapshot,
}
impl ForkDbStateSnapshot {
    /// Looks up the storage value of `address` at `index` in the accounts held by the `local`
    /// cache layer.
    ///
    /// Returns `None` when the account is not held there or has no entry for `index`.
    fn get_storage(&self, address: Address, index: U256) -> Option<U256> {
        let account = self.local.accounts.get(&address)?;
        account.storage.get(&index).copied()
    }
}
// Read-only view over a snapshot. Data held by the `local` cache layer is preferred; the copied
// store contents and finally `local`'s own `DatabaseRef` implementation act as fallbacks.
impl DatabaseRef for ForkDbStateSnapshot {
    type Error = DatabaseError;

    /// Returns the account info for `address`.
    ///
    /// Lookup order: accounts held by the `local` cache layer, then the accounts copied into
    /// `state_snapshot`, then `local.basic_ref`.
    ///
    /// # Errors
    ///
    /// Propagates the error of `local.basic_ref` when the final fallback fails.
    fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
        if let Some(account) = self.local.accounts.get(&address) {
            return Ok(Some(account.info.clone()));
        }
        match self.state_snapshot.accounts.get(&address) {
            Some(info) => Ok(Some(info.clone())),
            None => self.local.basic_ref(address),
        }
    }

    /// Returns the bytecode for `code_hash`, as answered by the `local` cache layer.
    fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
        self.local.code_by_hash_ref(code_hash)
    }

    /// Returns the storage value of `address` at `index`.
    ///
    /// A value held by the `local` cache layer's accounts wins; otherwise the request is
    /// answered by `local`'s `DatabaseRef::storage_ref`.
    ///
    /// # Errors
    ///
    /// Propagates the error of `local`'s `storage_ref` when the fallback fails.
    fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
        // A single lookup is enough: repeating `get_storage` after a miss in `local.accounts`
        // inspects the same map with the same keys and can only miss again.
        //
        // NOTE(review): `state_snapshot.storage` is never consulted here, unlike
        // `state_snapshot.accounts` in `basic_ref` and `state_snapshot.block_hashes` in
        // `block_hash_ref` — confirm whether that is intentional.
        match self.get_storage(address, index) {
            Some(value) => Ok(value),
            None => DatabaseRef::storage_ref(&self.local, address, index),
        }
    }

    /// Returns the hash of block `number`.
    ///
    /// The block hashes copied into `state_snapshot` are checked first; on a miss the request
    /// is answered by `local.block_hash_ref`.
    fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
        match self.state_snapshot.block_hashes.get(&U256::from(number)) {
            Some(hash) => Ok(*hash),
            None => self.local.block_hash_ref(number),
        }
    }
}
// NOTE(review): the `needless_return` allowance is presumably required by the code that
// `#[tokio::test]` expands to — confirm before removing it.
#[cfg(test)]
#[allow(clippy::needless_return)]
mod tests {
    use super::*;
    use crate::backend::BlockchainDbMeta;
    use foundry_common::provider::get_http_provider;
    use std::collections::BTreeSet;

    /// Checks that account info written into the cache layer is what `Database::basic` returns
    /// for the same address afterwards.
    #[tokio::test(flavor = "multi_thread")]
    async fn fork_db_insert_basic_default() {
        let endpoint = foundry_test_utils::rpc::next_http_rpc_endpoint();
        let provider = get_http_provider(endpoint.clone());
        let meta = BlockchainDbMeta {
            cfg_env: Default::default(),
            block_env: Default::default(),
            hosts: BTreeSet::from([endpoint]),
        };
        let blockchain_db = BlockchainDb::new(meta, None);
        let backend =
            SharedBackend::spawn_backend(Arc::new(provider), blockchain_db.clone(), None).await;
        let mut fork_db = ForkedDatabase::new(backend, blockchain_db);

        // First read of a random address goes through the database and must yield some info.
        let address = Address::random();
        let fetched = Database::basic(&mut fork_db, address).unwrap();
        assert!(fetched.is_some());

        // Change the balance and store the modified info in the cache layer.
        let mut expected = fetched.unwrap();
        expected.balance = U256::from(500u64);
        fork_db.database_mut().insert_account_info(address, expected.clone());

        // Reading the address again must return exactly the stored info.
        let reloaded = Database::basic(&mut fork_db, address).unwrap();
        assert!(reloaded.is_some());
        assert_eq!(reloaded.unwrap(), expected);
    }
}