foundry_common/
abi.rs

1//! ABI related helper functions.
2
3use alloy_dyn_abi::{DynSolType, DynSolValue, FunctionExt, JsonAbiExt};
4use alloy_json_abi::{Error, Event, Function, JsonAbi, Param};
5use alloy_primitives::{Address, LogData, hex};
6use eyre::{Context, ContextCompat, Result};
7use foundry_block_explorers::{Client, contract::ContractMetadata, errors::EtherscanError};
8use foundry_config::Chain;
9use std::pin::Pin;
10
11pub fn encode_args<I, S>(inputs: &[Param], args: I) -> Result<Vec<DynSolValue>>
12where
13    I: IntoIterator<Item = S>,
14    S: AsRef<str>,
15{
16    let args: Vec<S> = args.into_iter().collect();
17
18    if inputs.len() != args.len() {
19        eyre::bail!("encode length mismatch: expected {} types, got {}", inputs.len(), args.len())
20    }
21
22    std::iter::zip(inputs, args)
23        .map(|(input, arg)| coerce_value(&input.selector_type(), arg.as_ref()))
24        .collect()
25}
26
27/// Given a function and a vector of string arguments, it proceeds to convert the args to alloy
28/// [DynSolValue]s and then ABI encode them, prefixes the encoded data with the function selector.
29pub fn encode_function_args<I, S>(func: &Function, args: I) -> Result<Vec<u8>>
30where
31    I: IntoIterator<Item = S>,
32    S: AsRef<str>,
33{
34    Ok(func.abi_encode_input(&encode_args(&func.inputs, args)?)?)
35}
36
37/// Given a function and a vector of string arguments, it proceeds to convert the args to alloy
38/// [DynSolValue]s and then ABI encode them. Doesn't prefix the function selector.
39pub fn encode_function_args_raw<I, S>(func: &Function, args: I) -> Result<Vec<u8>>
40where
41    I: IntoIterator<Item = S>,
42    S: AsRef<str>,
43{
44    Ok(func.abi_encode_input_raw(&encode_args(&func.inputs, args)?)?)
45}
46
47/// Given a function and a vector of string arguments, it proceeds to convert the args to alloy
48/// [DynSolValue]s and encode them using the packed encoding.
49pub fn encode_function_args_packed<I, S>(func: &Function, args: I) -> Result<Vec<u8>>
50where
51    I: IntoIterator<Item = S>,
52    S: AsRef<str>,
53{
54    let args: Vec<S> = args.into_iter().collect();
55
56    if func.inputs.len() != args.len() {
57        eyre::bail!(
58            "encode length mismatch: expected {} types, got {}",
59            func.inputs.len(),
60            args.len(),
61        );
62    }
63
64    let params: Vec<Vec<u8>> = std::iter::zip(&func.inputs, args)
65        .map(|(input, arg)| coerce_value(&input.selector_type(), arg.as_ref()))
66        .collect::<Result<Vec<_>>>()?
67        .into_iter()
68        .map(|v| v.abi_encode_packed())
69        .collect();
70
71    Ok(params.concat())
72}
73
74/// Decodes the calldata of the function
75pub fn abi_decode_calldata(
76    sig: &str,
77    calldata: &str,
78    input: bool,
79    fn_selector: bool,
80) -> Result<Vec<DynSolValue>> {
81    let func = get_func(sig)?;
82    let calldata = hex::decode(calldata)?;
83
84    let mut calldata = calldata.as_slice();
85    // If function selector is prefixed in "calldata", remove it (first 4 bytes)
86    if input && fn_selector && calldata.len() >= 4 {
87        calldata = &calldata[4..];
88    }
89
90    let res =
91        if input { func.abi_decode_input(calldata) } else { func.abi_decode_output(calldata) }?;
92
93    // in case the decoding worked but nothing was decoded
94    if res.is_empty() {
95        eyre::bail!("no data was decoded")
96    }
97
98    Ok(res)
99}
100
101/// Given a function signature string, it tries to parse it as a `Function`
102pub fn get_func(sig: &str) -> Result<Function> {
103    Function::parse(sig).wrap_err("could not parse function signature")
104}
105
106/// Given an event signature string, it tries to parse it as a `Event`
107pub fn get_event(sig: &str) -> Result<Event> {
108    Event::parse(sig).wrap_err("could not parse event signature")
109}
110
111/// Given an error signature string, it tries to parse it as a `Error`
112pub fn get_error(sig: &str) -> Result<Error> {
113    Error::parse(sig).wrap_err("could not parse error signature")
114}
115
116/// Given an event without indexed parameters and a rawlog, it tries to return the event with the
117/// proper indexed parameters. Otherwise, it returns the original event.
118pub fn get_indexed_event(mut event: Event, raw_log: &LogData) -> Event {
119    if !event.anonymous && raw_log.topics().len() > 1 {
120        let indexed_params = raw_log.topics().len() - 1;
121        let num_inputs = event.inputs.len();
122        let num_address_params = event.inputs.iter().filter(|p| p.ty == "address").count();
123
124        event.inputs.iter_mut().enumerate().for_each(|(index, param)| {
125            if param.name.is_empty() {
126                param.name = format!("param{index}");
127            }
128            if num_inputs == indexed_params
129                || (num_address_params == indexed_params && param.ty == "address")
130            {
131                param.indexed = true;
132            }
133        })
134    }
135    event
136}
137
138/// Fetches the ABI of a contract from Etherscan.
139pub async fn fetch_abi_from_etherscan(
140    address: Address,
141    config: &foundry_config::Config,
142) -> Result<Vec<(JsonAbi, String)>> {
143    let chain = config.chain.unwrap_or_default();
144    let api_key = config.get_etherscan_api_key(Some(chain)).unwrap_or_default();
145    let client = Client::new(chain, api_key)?;
146    let source = client.contract_source_code(address).await?;
147    source.items.into_iter().map(|item| Ok((item.abi()?, item.contract_name))).collect()
148}
149
150/// Given a function name, address, and args, tries to parse it as a `Function` by fetching the
151/// abi from etherscan. If the address is a proxy, fetches the ABI of the implementation contract.
152pub async fn get_func_etherscan(
153    function_name: &str,
154    contract: Address,
155    args: &[String],
156    chain: Chain,
157    etherscan_api_key: &str,
158) -> Result<Function> {
159    let client = Client::new(chain, etherscan_api_key)?;
160    let source = find_source(client, contract).await?;
161    let metadata = source.items.first().wrap_err("etherscan returned empty metadata")?;
162
163    let mut abi = metadata.abi()?;
164    let funcs = abi.functions.remove(function_name).unwrap_or_default();
165
166    for func in funcs {
167        let res = encode_function_args(&func, args);
168        if res.is_ok() {
169            return Ok(func);
170        }
171    }
172
173    Err(eyre::eyre!("Function not found in abi"))
174}
175
176/// If the code at `address` is a proxy, recurse until we find the implementation.
177pub fn find_source(
178    client: Client,
179    address: Address,
180) -> Pin<Box<dyn Future<Output = Result<ContractMetadata>>>> {
181    Box::pin(async move {
182        trace!(%address, "find Etherscan source");
183        let source = client.contract_source_code(address).await?;
184        let metadata = source.items.first().wrap_err("Etherscan returned no data")?;
185        if metadata.proxy == 0 {
186            Ok(source)
187        } else {
188            let implementation = metadata.implementation.unwrap();
189            sh_println!(
190                "Contract at {address} is a proxy, trying to fetch source at {implementation}..."
191            )?;
192            match find_source(client, implementation).await {
193                impl_source @ Ok(_) => impl_source,
194                Err(e) => {
195                    let err = EtherscanError::ContractCodeNotVerified(address).to_string();
196                    if e.to_string() == err {
197                        error!(%err);
198                        Ok(source)
199                    } else {
200                        Err(e)
201                    }
202                }
203            }
204        }
205    })
206}
207
208/// Helper function to coerce a value to a [DynSolValue] given a type string
209pub fn coerce_value(ty: &str, arg: &str) -> Result<DynSolValue> {
210    let ty = DynSolType::parse(ty)?;
211    Ok(DynSolType::coerce_str(&ty, arg)?)
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_dyn_abi::EventExt;
    use alloy_primitives::{B256, U256};

    #[test]
    fn test_get_func() {
        let func = get_func("function foo(uint256 a, uint256 b) returns (uint256)")
            .expect("full function signature should parse");
        assert_eq!(func.name, "foo");
        assert_eq!(func.inputs.len(), 2);
        assert_eq!(func.inputs[0].ty, "uint256");
        assert_eq!(func.inputs[1].ty, "uint256");

        // Stripped down function, which [Function] can parse.
        let func = get_func("foo(bytes4 a, uint8 b)(bytes4)")
            .expect("stripped-down function signature should parse");
        assert_eq!(func.name, "foo");
        assert_eq!(func.inputs.len(), 2);
        assert_eq!(func.inputs[0].ty, "bytes4");
        assert_eq!(func.inputs[1].ty, "uint8");
        assert_eq!(func.outputs[0].ty, "bytes4");
    }

    #[test]
    fn test_indexed_only_address() {
        let event = get_event("event Ev(address,uint256,address)").unwrap();

        let param0 = B256::random();
        let param1 = vec![3; 32];
        let param2 = B256::random();
        let log = LogData::new_unchecked(vec![event.selector(), param0, param2], param1.into());
        let event = get_indexed_event(event, &log);

        assert_eq!(event.inputs.len(), 3);

        // Only the address fields get indexed since total_params > num_indexed_params
        let parsed = event.decode_log(&log).unwrap();

        assert_eq!(event.inputs.iter().filter(|param| param.indexed).count(), 2);
        assert_eq!(parsed.indexed[0], DynSolValue::Address(Address::from_word(param0)));
        assert_eq!(parsed.body[0], DynSolValue::Uint(U256::from_be_bytes([3; 32]), 256));
        assert_eq!(parsed.indexed[1], DynSolValue::Address(Address::from_word(param2)));
    }

    #[test]
    fn test_indexed_all() {
        let event = get_event("event Ev(address,uint256,address)").unwrap();

        let param0 = B256::random();
        let param1 = vec![3; 32];
        let param2 = B256::random();
        let log = LogData::new_unchecked(
            vec![event.selector(), param0, B256::from_slice(&param1), param2],
            vec![].into(),
        );
        let event = get_indexed_event(event, &log);

        assert_eq!(event.inputs.len(), 3);

        // All parameters get indexed since num_indexed_params == total_params
        assert_eq!(event.inputs.iter().filter(|param| param.indexed).count(), 3);
        let parsed = event.decode_log(&log).unwrap();

        assert_eq!(parsed.indexed[0], DynSolValue::Address(Address::from_word(param0)));
        assert_eq!(parsed.indexed[1], DynSolValue::Uint(U256::from_be_bytes([3; 32]), 256));
        assert_eq!(parsed.indexed[2], DynSolValue::Address(Address::from_word(param2)));
    }

    #[test]
    fn test_encode_args_length_validation() {
        // `Param` is already in scope through `use super::*`.
        let params = vec![
            Param {
                name: "a".to_string(),
                ty: "uint256".to_string(),
                internal_type: None,
                components: vec![],
            },
            Param {
                name: "b".to_string(),
                ty: "address".to_string(),
                internal_type: None,
                components: vec![],
            },
        ];

        // Less arguments than parameters
        let args = vec!["1"];
        let err = encode_args(&params, &args).unwrap_err();
        assert!(err.to_string().contains("encode length mismatch"));

        // Exact number of arguments and parameters
        let args = vec!["1", "0x0000000000000000000000000000000000000001"];
        let values = encode_args(&params, &args).expect("matching arity should encode");
        assert_eq!(values.len(), 2);

        // More arguments than parameters
        let args = vec!["1", "0x0000000000000000000000000000000000000001", "extra"];
        let err = encode_args(&params, &args).unwrap_err();
        assert!(err.to_string().contains("encode length mismatch"));
    }

    #[test]
    fn test_encode_function_args_packed() {
        let func = get_func("function f(uint8 a, address b)").unwrap();
        let packed = encode_function_args_packed(
            &func,
            ["1", "0x0000000000000000000000000000000000000002"],
        )
        .unwrap();

        // uint8 packs to a single byte and address to 20 bytes: no padding, no selector.
        let mut expected = vec![1u8];
        expected.extend_from_slice(Address::with_last_byte(2).as_slice());
        assert_eq!(packed, expected);

        let err = encode_function_args_packed(&func, ["1"]).unwrap_err();
        assert!(err.to_string().contains("encode length mismatch"));
    }

    #[test]
    fn test_abi_decode_calldata() {
        let func = get_func("f(uint256)").unwrap();
        let calldata = encode_function_args(&func, ["42"]).unwrap();

        let decoded =
            abi_decode_calldata("f(uint256)", &hex::encode(&calldata), true, true).unwrap();
        assert_eq!(decoded, vec![DynSolValue::Uint(U256::from(42), 256)]);

        // Too short to even contain a function selector.
        assert!(abi_decode_calldata("f(uint256)", "0x1234", true, true).is_err());
    }
}