diff --git a/crates/cheatcodes/src/inspector.rs b/crates/cheatcodes/src/inspector.rs index 8c4b6441c74e4..de918b9b3d664 100644 --- a/crates/cheatcodes/src/inspector.rs +++ b/crates/cheatcodes/src/inspector.rs @@ -1256,6 +1256,11 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes { // Handle assume not revert cheatcode. if let Some(assume_no_revert) = &mut self.assume_no_revert { + // Record current reverter address before processing the expect revert if call reverted, + // expect revert is set with expected reverter address and no actual reverter set yet. + if outcome.result.is_revert() && assume_no_revert.reverted_by.is_none() { + assume_no_revert.reverted_by = Some(call.target_address); + } // allow multiple cheatcode calls at the same depth if ecx.journaled_state.depth() <= assume_no_revert.depth && !cheatcode_call { // Discard run if we're at the same depth as cheatcode, call reverted, and no @@ -1267,7 +1272,7 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes { outcome.result.result, &outcome.result.output, &self.config.available_artifacts, - Some(&call.target_address), + assume_no_revert.reverted_by.as_ref(), ) { // if result is Ok, it was an anticipated revert; return an "assume" error // to reject this run @@ -1294,6 +1299,14 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes { // Handle expected reverts. if let Some(expected_revert) = &mut self.expected_revert { + // Record current reverter address before processing the expect revert if call reverted, + // expect revert is set with expected reverter address and no actual reverter set yet. + if outcome.result.is_revert() && + expected_revert.reverter.is_some() && + expected_revert.reverted_by.is_none() + { + expected_revert.reverted_by = Some(call.target_address); + } if ecx.journaled_state.depth() <= expected_revert.depth { let needs_processing = match expected_revert.kind { ExpectedRevertKind::Default => !cheatcode_call, @@ -1313,7 +1326,7 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes { outcome.result.result, outcome.result.output.clone(), &self.config.available_artifacts, - Some(&call.target_address), + expected_revert.reverted_by.as_ref(), ) { Err(error) => { trace!(expected=?expected_revert, ?error, status=?outcome.result.result, "Expected revert mismatch"); diff --git a/crates/cheatcodes/src/test/assume.rs b/crates/cheatcodes/src/test/assume.rs index 3da9d730ad0b4..d3ab4cacd760c 100644 --- a/crates/cheatcodes/src/test/assume.rs +++ b/crates/cheatcodes/src/test/assume.rs @@ -11,6 +11,10 @@ use std::fmt::Debug; use super::revert::{handle_revert, RevertParameters}; +pub const ASSUME_EXPECT_REJECT_MAGIC: &str = "Cannot combine an assumeNoRevert with expectRevert"; +pub const ASSUME_REJECT_MAGIC: &str = + "Cannot combine a generic assumeNoRevert with specific assumeNoRevert reasons"; + #[derive(Clone, Debug)] pub struct AssumeNoRevert { /// The call depth at which the cheatcode was added. @@ -19,6 +23,8 @@ pub struct AssumeNoRevert { /// reverts with parameters not specified here will count as normal reverts and not rejects /// towards the counter. pub reasons: Option>, + /// Address that reverted the call. + pub reverted_by: Option
, } /// Parameters for a single anticipated revert, to be thrown out if encountered. @@ -30,8 +36,6 @@ pub struct AcceptableRevertParameters { pub partial_match: bool, /// Contract expected to revert next call. pub reverter: Option
, - /// Actual reverter of the call. - pub reverted_by: Option
, } impl RevertParameters for AcceptableRevertParameters { @@ -147,38 +151,30 @@ fn assume_no_revert( partial_match: bool, reverter: Option
, ) -> Result { - ensure!(state.expected_revert.is_none(), ""); + ensure!(state.expected_revert.is_none(), ASSUME_EXPECT_REJECT_MAGIC); // if assume_no_revert is not set, set it if state.assume_no_revert.is_none() { - state.assume_no_revert = Some(AssumeNoRevert { depth, reasons: None }); + state.assume_no_revert = Some(AssumeNoRevert { depth, reasons: None, reverted_by: None }); // if reason is not none, create a new AssumeNoRevertParams vec if let Some(reason) = reason { state.assume_no_revert.as_mut().unwrap().reasons = - Some(vec![AcceptableRevertParameters { - reason, - partial_match, - reverter, - reverted_by: None, - }]); + Some(vec![AcceptableRevertParameters { reason, partial_match, reverter }]); } } else { // otherwise, ensure that reasons vec is not none and new reason is also not none let valid_assume = state.assume_no_revert.as_ref().unwrap().reasons.is_some() && reason.is_some(); - ensure!( - valid_assume, - "cannot combine a generic assumeNoRevert with specific assumeNoRevert reasons" - ); + ensure!(valid_assume, ASSUME_REJECT_MAGIC); // and append the new reason - state.assume_no_revert.as_mut().unwrap().reasons.as_mut().unwrap().push( - AcceptableRevertParameters { - reason: reason.unwrap(), - partial_match, - reverter, - reverted_by: None, - }, - ); + state + .assume_no_revert + .as_mut() + .unwrap() + .reasons + .as_mut() + .unwrap() + .push(AcceptableRevertParameters { reason: reason.unwrap(), partial_match, reverter }); } Ok(Default::default()) @@ -190,16 +186,30 @@ pub(crate) fn handle_assume_no_revert( retdata: &Bytes, known_contracts: &Option, reverter: Option<&Address>, -) -> Result<(), Error> { +) -> Result<()> { // iterate over acceptable reasons and try to match against any, otherwise, return an Error with // the revert data - assume_no_revert - .reasons - .as_ref() - .and_then(|reasons| { - reasons.iter().find_map(|reason| { - handle_revert(false, reason, status, retdata, known_contracts, reverter).ok() - }) - }) - .ok_or_else(|| retdata.clone().into()) + assume_no_revert.reasons.as_ref().map_or_else( + || { + // todo: fix this hack to get cheatcode name in error message + let retdata_str = retdata.to_string(); + if retdata_str.contains(ASSUME_REJECT_MAGIC) || + retdata_str.contains(ASSUME_EXPECT_REJECT_MAGIC) + { + // raise error with retdata as a string, so apply_dispatch will insert the cheatcode + // name + Err(Error::from(String::from_utf8(retdata.to_vec()).unwrap())) + } else { + Ok(()) + } + }, + |reasons| { + reasons + .iter() + .find_map(|reason| { + handle_revert(false, reason, status, retdata, known_contracts, reverter).ok() + }) + .ok_or_else(|| retdata.clone().into()) + }, + ) } diff --git a/crates/cheatcodes/src/test/expect.rs b/crates/cheatcodes/src/test/expect.rs index db2b7a0cf6f9f..94f50cc97fd17 100644 --- a/crates/cheatcodes/src/test/expect.rs +++ b/crates/cheatcodes/src/test/expect.rs @@ -84,6 +84,8 @@ pub struct ExpectedRevert { pub partial_match: bool, /// Contract expected to revert next call. pub reverter: Option
, + /// Address that reverted the call. + pub reverted_by: Option
, } #[derive(Clone, Debug)] @@ -689,7 +691,7 @@ fn expect_revert( }, partial_match, reverter, - // reverted_by: None, + reverted_by: None, }); Ok(Default::default()) }