From f3130a52937039a7e82b4aef07143139c322d09f Mon Sep 17 00:00:00 2001 From: anukul <44864521+anukul@users.noreply.github.com> Date: Thu, 20 Feb 2025 01:03:15 -0800 Subject: [PATCH] feat(cheatcodes): add `expectCreate` and `expectCreate2` (#9875) * add expectCreate and expectCreate2 cheatcodes * add tests * apply clippy fixes * apply clippy fixes * fix failing test * fix failing test * fix failing test * fix failing test: use line wildcards * add requested changes * move nested creates to single test * Fix test --------- Co-authored-by: zerosnacks <95942363+zerosnacks@users.noreply.github.com> Co-authored-by: grandizzy <38490174+grandizzy@users.noreply.github.com> Co-authored-by: grandizzy --- crates/cheatcodes/assets/cheatcodes.json | 40 +++++++++++ crates/cheatcodes/spec/src/vm.rs | 8 +++ crates/cheatcodes/src/inspector.rs | 58 ++++++++++++++-- crates/cheatcodes/src/test/expect.rs | 66 ++++++++++++++++++- crates/forge/tests/cli/failure_assertions.rs | 24 +++++++ .../tests/fixtures/ExpectCreateFailures.t.sol | 62 +++++++++++++++++ testdata/cheats/Vm.sol | 2 + testdata/default/cheats/ExpectCreate.t.sol | 44 +++++++++++++ 8 files changed, 296 insertions(+), 8 deletions(-) create mode 100644 crates/forge/tests/fixtures/ExpectCreateFailures.t.sol create mode 100644 testdata/default/cheats/ExpectCreate.t.sol diff --git a/crates/cheatcodes/assets/cheatcodes.json b/crates/cheatcodes/assets/cheatcodes.json index 3d528b903bf1..50189117bd26 100644 --- a/crates/cheatcodes/assets/cheatcodes.json +++ b/crates/cheatcodes/assets/cheatcodes.json @@ -4872,6 +4872,46 @@ "status": "stable", "safety": "unsafe" }, + { + "func": { + "id": "expectCreate", + "description": "Expects the deployment of the specified bytecode by the specified address using the CREATE opcode", + "declaration": "function expectCreate(bytes calldata bytecode, address deployer) external;", + "visibility": "external", + "mutability": "", + "signature": "expectCreate(bytes,address)", + "selector": "0x73cdce36", + "selectorBytes": [ + 115, + 205, + 206, + 54 + ] + }, + "group": "testing", + "status": "stable", + "safety": "unsafe" + }, + { + "func": { + "id": "expectCreate2", + "description": "Expects the deployment of the specified bytecode by the specified address using the CREATE2 opcode", + "declaration": "function expectCreate2(bytes calldata bytecode, address deployer) external;", + "visibility": "external", + "mutability": "", + "signature": "expectCreate2(bytes,address)", + "selector": "0xea54a472", + "selectorBytes": [ + 234, + 84, + 164, + 114 + ] + }, + "group": "testing", + "status": "stable", + "safety": "unsafe" + }, { "func": { "id": "expectEmitAnonymous_0", diff --git a/crates/cheatcodes/spec/src/vm.rs b/crates/cheatcodes/spec/src/vm.rs index f2eaeb697a18..1891cc45f470 100644 --- a/crates/cheatcodes/spec/src/vm.rs +++ b/crates/cheatcodes/spec/src/vm.rs @@ -1040,6 +1040,14 @@ interface Vm { #[cheatcode(group = Testing, safety = Unsafe)] function expectEmitAnonymous(address emitter) external; + /// Expects the deployment of the specified bytecode by the specified address using the CREATE opcode + #[cheatcode(group = Testing, safety = Unsafe)] + function expectCreate(bytes calldata bytecode, address deployer) external; + + /// Expects the deployment of the specified bytecode by the specified address using the CREATE2 opcode + #[cheatcode(group = Testing, safety = Unsafe)] + function expectCreate2(bytes calldata bytecode, address deployer) external; + /// Expects an error on next call with any revert data. #[cheatcode(group = Testing, safety = Unsafe)] function expectRevert() external; diff --git a/crates/cheatcodes/src/inspector.rs b/crates/cheatcodes/src/inspector.rs index cf6248c8f02b..4c6a2202f954 100644 --- a/crates/cheatcodes/src/inspector.rs +++ b/crates/cheatcodes/src/inspector.rs @@ -12,8 +12,8 @@ use crate::{ test::{ assume::AssumeNoRevert, expect::{ - self, ExpectedCallData, ExpectedCallTracker, ExpectedCallType, ExpectedEmitTracker, - ExpectedRevert, ExpectedRevertKind, + self, ExpectedCallData, ExpectedCallTracker, ExpectedCallType, ExpectedCreate, + ExpectedEmitTracker, ExpectedRevert, ExpectedRevertKind, }, revert_handlers, }, @@ -431,6 +431,8 @@ pub struct Cheatcodes { pub expected_calls: ExpectedCallTracker, /// Expected emits pub expected_emits: ExpectedEmitTracker, + /// Expected creates + pub expected_creates: Vec, /// Map of context depths to memory offset ranges that may be written to within the call depth. pub allowed_mem_writes: HashMap>>, @@ -521,6 +523,7 @@ impl Cheatcodes { mocked_functions: Default::default(), expected_calls: Default::default(), expected_emits: Default::default(), + expected_creates: Default::default(), allowed_mem_writes: Default::default(), broadcast: Default::default(), broadcastable_transactions: Default::default(), @@ -723,7 +726,12 @@ impl Cheatcodes { } // common create_end functionality for both legacy and EOF. - fn create_end_common(&mut self, ecx: Ecx, mut outcome: CreateOutcome) -> CreateOutcome + fn create_end_common( + &mut self, + ecx: Ecx, + call: Option<&CreateInputs>, + mut outcome: CreateOutcome, + ) -> CreateOutcome where { let ecx = &mut ecx.inner; @@ -834,6 +842,26 @@ where { } } } + + // Match the create against expected_creates + if !self.expected_creates.is_empty() { + if let (Some(address), Some(call)) = (outcome.address, call) { + if let Ok(created_acc) = ecx.journaled_state.load_account(address, &mut ecx.db) { + let bytecode = + created_acc.info.code.clone().unwrap_or_default().original_bytes(); + if let Some((index, _)) = + self.expected_creates.iter().find_position(|expected_create| { + expected_create.deployer == call.caller && + expected_create.create_scheme.eq(call.scheme) && + expected_create.bytecode == bytecode + }) + { + self.expected_creates.swap_remove(index); + } + } + } + } + outcome } @@ -1565,7 +1593,9 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes { } // If there's not a revert, we can continue on to run the last logic for expect* - // cheatcodes. Match expected calls + // cheatcodes. + + // Match expected calls for (address, calldatas) in &self.expected_calls { // Loop over each address, and for each address, loop over each calldata it expects. for (calldata, (expected, actual_count)) in calldatas { @@ -1613,6 +1643,7 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes { } } } + // Check if we have any leftover expected emits // First, if any emits were found at the root call, then we its ok and we remove them. self.expected_emits.retain(|(expected, _)| expected.count > 0 && !expected.found); @@ -1629,6 +1660,19 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes { outcome.result.output = Error::encode(msg); return outcome; } + + // Check for leftover expected creates + if let Some(expected_create) = self.expected_creates.first() { + let msg = format!( + "expected {} call by address {} for bytecode {} but not found", + expected_create.create_scheme, + hex::encode_prefixed(expected_create.deployer), + hex::encode_prefixed(&expected_create.bytecode), + ); + outcome.result.result = InstructionResult::Revert; + outcome.result.output = Error::encode(msg); + return outcome; + } } outcome @@ -1641,10 +1685,10 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes { fn create_end( &mut self, ecx: Ecx, - _call: &CreateInputs, + call: &CreateInputs, outcome: CreateOutcome, ) -> CreateOutcome { - self.create_end_common(ecx, outcome) + self.create_end_common(ecx, Some(call), outcome) } fn eofcreate(&mut self, ecx: Ecx, call: &mut EOFCreateInputs) -> Option { @@ -1657,7 +1701,7 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes { _call: &EOFCreateInputs, outcome: CreateOutcome, ) -> CreateOutcome { - self.create_end_common(ecx, outcome) + self.create_end_common(ecx, None, outcome) } } diff --git a/crates/cheatcodes/src/test/expect.rs b/crates/cheatcodes/src/test/expect.rs index 11d0e65a0237..843ab468594d 100644 --- a/crates/cheatcodes/src/test/expect.rs +++ b/crates/cheatcodes/src/test/expect.rs @@ -1,4 +1,7 @@ -use std::collections::VecDeque; +use std::{ + collections::VecDeque, + fmt::{self, Display}, +}; use crate::{Cheatcode, Cheatcodes, CheatsCtxt, Error, Result, Vm::*}; use alloy_primitives::{ @@ -104,6 +107,41 @@ pub struct ExpectedEmit { pub count: u64, } +#[derive(Clone, Debug)] +pub struct ExpectedCreate { + /// The address that deployed the contract + pub deployer: Address, + /// Runtime bytecode of the contract + pub bytecode: Bytes, + /// Whether deployed with CREATE or CREATE2 + pub create_scheme: CreateScheme, +} + +#[derive(Clone, Debug)] +pub enum CreateScheme { + Create, + Create2, +} + +impl Display for CreateScheme { + fn fmt(&self, f: &mut fmt::Formatter) -> std::fmt::Result { + match self { + Self::Create => write!(f, "CREATE"), + Self::Create2 => write!(f, "CREATE2"), + } + } +} + +impl CreateScheme { + pub fn eq(&self, create_scheme: revm::primitives::CreateScheme) -> bool { + matches!( + (self, create_scheme), + (Self::Create, revm::primitives::CreateScheme::Create) | + (Self::Create2, revm::primitives::CreateScheme::Create2 { .. }) + ) + } +} + impl Cheatcode for expectCall_0Call { fn apply(&self, state: &mut Cheatcodes) -> Result { let Self { callee, data } = self; @@ -338,6 +376,20 @@ impl Cheatcode for expectEmitAnonymous_3Call { } } +impl Cheatcode for expectCreateCall { + fn apply(&self, state: &mut Cheatcodes) -> Result { + let Self { bytecode, deployer } = self; + expect_create(state, bytecode.clone(), *deployer, CreateScheme::Create) + } +} + +impl Cheatcode for expectCreate2Call { + fn apply(&self, state: &mut Cheatcodes) -> Result { + let Self { bytecode, deployer } = self; + expect_create(state, bytecode.clone(), *deployer, CreateScheme::Create2) + } +} + impl Cheatcode for expectRevert_0Call { fn apply_stateful(&self, ccx: &mut CheatsCtxt) -> Result { let Self {} = self; @@ -889,6 +941,18 @@ impl LogCountMap { } } +fn expect_create( + state: &mut Cheatcodes, + bytecode: Bytes, + deployer: Address, + create_scheme: CreateScheme, +) -> Result { + let expected_create = ExpectedCreate { bytecode, deployer, create_scheme }; + state.expected_creates.push(expected_create); + + Ok(Default::default()) +} + fn expect_revert( state: &mut Cheatcodes, reason: Option<&[u8]>, diff --git a/crates/forge/tests/cli/failure_assertions.rs b/crates/forge/tests/cli/failure_assertions.rs index 611d4d0bd221..3913a04bdbe1 100644 --- a/crates/forge/tests/cli/failure_assertions.rs +++ b/crates/forge/tests/cli/failure_assertions.rs @@ -156,6 +156,30 @@ Suite result: FAILED. 0 passed; 3 failed; 0 skipped; [ELAPSED] ); }); +forgetest!(expect_create_tests_should_fail, |prj, cmd| { + prj.insert_ds_test(); + prj.insert_vm(); + + let expect_create_failures = include_str!("../fixtures/ExpectCreateFailures.t.sol"); + + prj.add_source("ExpectCreateFailures.t.sol", expect_create_failures).unwrap(); + + cmd.forge_fuse().args(["test", "--mc", "ExpectCreateFailureTest"]).assert_failure().stdout_eq(str![[r#" +... +[FAIL: expected CREATE call by address 0x7fa9385be102ac3eac297483dd6233d62b3e1496 for bytecode [..] but not found] testShouldFailExpectCreate() ([GAS]) +[FAIL: expected CREATE2 call by address 0x7fa9385be102ac3eac297483dd6233d62b3e1496 for bytecode [..] but not found] testShouldFailExpectCreate2() ([GAS]) +[FAIL: expected CREATE2 call by address 0x7fa9385be102ac3eac297483dd6233d62b3e1496 for bytecode [..] but not found] testShouldFailExpectCreate2WrongBytecode() ([GAS]) +[FAIL: expected CREATE2 call by address 0x0000000000000000000000000000000000000000 for bytecode [..] but not found] testShouldFailExpectCreate2WrongDeployer() ([GAS]) +[FAIL: expected CREATE2 call by address 0x7fa9385be102ac3eac297483dd6233d62b3e1496 for bytecode [..] but not found] testShouldFailExpectCreate2WrongScheme() ([GAS]) +[FAIL: expected CREATE call by address 0x7fa9385be102ac3eac297483dd6233d62b3e1496 for bytecode [..] but not found] testShouldFailExpectCreateWrongBytecode() ([GAS]) +[FAIL: expected CREATE call by address 0x0000000000000000000000000000000000000000 for bytecode [..] but not found] testShouldFailExpectCreateWrongDeployer() ([GAS]) +[FAIL: expected CREATE call by address 0x7fa9385be102ac3eac297483dd6233d62b3e1496 for bytecode [..] but not found] testShouldFailExpectCreateWrongScheme() ([GAS]) +Suite result: FAILED. 0 passed; 8 failed; 0 skipped; [ELAPSED] +... + +"#]]); +}); + forgetest!(expect_emit_tests_should_fail, |prj, cmd| { prj.insert_ds_test(); prj.insert_vm(); diff --git a/crates/forge/tests/fixtures/ExpectCreateFailures.t.sol b/crates/forge/tests/fixtures/ExpectCreateFailures.t.sol new file mode 100644 index 000000000000..ebd5df73d50b --- /dev/null +++ b/crates/forge/tests/fixtures/ExpectCreateFailures.t.sol @@ -0,0 +1,62 @@ +// Note Used in forge-cli tests to assert failures. +// SPDX-License-Identifier: MIT OR Apache-2.0 +pragma solidity ^0.8.18; + +import "./test.sol"; +import "./Vm.sol"; + +contract Contract { + function add(uint256 a, uint256 b) public pure returns (uint256) { + return a + b; + } +} + +contract OtherContract { + function sub(uint256 a, uint256 b) public pure returns (uint256) { + return a - b; + } +} + +contract ExpectCreateFailureTest is DSTest { + Vm constant vm = Vm(HEVM_ADDRESS); + bytes contractBytecode = + vm.getDeployedCode("ExpectCreateFailures.t.sol:Contract"); + + function testShouldFailExpectCreate() public { + vm.expectCreate(contractBytecode, address(this)); + } + + function testShouldFailExpectCreate2() public { + vm.expectCreate2(contractBytecode, address(this)); + } + + function testShouldFailExpectCreateWrongBytecode() public { + vm.expectCreate(contractBytecode, address(this)); + new OtherContract(); + } + + function testShouldFailExpectCreate2WrongBytecode() public { + vm.expectCreate2(contractBytecode, address(this)); + new OtherContract{salt: "foobar"}(); + } + + function testShouldFailExpectCreateWrongDeployer() public { + vm.expectCreate(contractBytecode, address(0)); + new Contract(); + } + + function testShouldFailExpectCreate2WrongDeployer() public { + vm.expectCreate2(contractBytecode, address(0)); + new Contract(); + } + + function testShouldFailExpectCreateWrongScheme() public { + vm.expectCreate(contractBytecode, address(this)); + new Contract{salt: "foobar"}(); + } + + function testShouldFailExpectCreate2WrongScheme() public { + vm.expectCreate2(contractBytecode, address(this)); + new Contract(); + } +} diff --git a/testdata/cheats/Vm.sol b/testdata/cheats/Vm.sol index cc873d3e1be2..01c581366ffc 100644 --- a/testdata/cheats/Vm.sol +++ b/testdata/cheats/Vm.sol @@ -237,6 +237,8 @@ interface Vm { function expectCall(address callee, uint256 msgValue, bytes calldata data, uint64 count) external; function expectCall(address callee, uint256 msgValue, uint64 gas, bytes calldata data) external; function expectCall(address callee, uint256 msgValue, uint64 gas, bytes calldata data, uint64 count) external; + function expectCreate(bytes calldata bytecode, address deployer) external; + function expectCreate2(bytes calldata bytecode, address deployer) external; function expectEmitAnonymous(bool checkTopic0, bool checkTopic1, bool checkTopic2, bool checkTopic3, bool checkData) external; function expectEmitAnonymous(bool checkTopic0, bool checkTopic1, bool checkTopic2, bool checkTopic3, bool checkData, address emitter) external; function expectEmitAnonymous() external; diff --git a/testdata/default/cheats/ExpectCreate.t.sol b/testdata/default/cheats/ExpectCreate.t.sol new file mode 100644 index 000000000000..a922d01b92ba --- /dev/null +++ b/testdata/default/cheats/ExpectCreate.t.sol @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT OR Apache-2.0 +pragma solidity ^0.8.18; + +import "ds-test/test.sol"; +import "cheats/Vm.sol"; + +contract Contract { + function add(uint256 a, uint256 b) public pure returns (uint256) { + return a + b; + } +} + +contract ContractDeployer { + function deployContract() public { + new Contract(); + } + + function deployContractCreate2() public { + new Contract{salt: "foo"}(); + } +} + +contract ExpectCreateTest is DSTest { + Vm constant vm = Vm(HEVM_ADDRESS); + bytes bytecode = vm.getDeployedCode("cheats/ExpectCreate.t.sol:Contract"); + + function testExpectCreate() public { + vm.expectCreate(bytecode, address(this)); + new Contract(); + } + + function testExpectCreate2() public { + vm.expectCreate2(bytecode, address(this)); + new Contract{salt: "foo"}(); + } + + function testExpectNestedCreate() public { + ContractDeployer foo = new ContractDeployer(); + vm.expectCreate(bytecode, address(foo)); + vm.expectCreate2(bytecode, address(foo)); + foo.deployContract(); + foo.deployContractCreate2(); + } +}