From 0ff1663c2c2ec39200643a0b887081d5100d95af Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Fri, 31 Jan 2025 14:45:58 +0100 Subject: [PATCH 01/18] feat(snforge_std): add "dynamic" return value for mock_call --- snforge_std/src/cheatcodes.cairo | 58 +++++++++++++++++++++++++++----- snforge_std/src/lib.cairo | 1 + 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/snforge_std/src/cheatcodes.cairo b/snforge_std/src/cheatcodes.cairo index 1b3665bc9d..277df84b46 100644 --- a/snforge_std/src/cheatcodes.cairo +++ b/snforge_std/src/cheatcodes.cairo @@ -20,6 +20,36 @@ pub enum CheatSpan { TargetCalls: usize, } +/// Enum used to specify the call data that should be matched when mocking a contract call. +#[derive(Copy, Drop, PartialEq, Clone, Debug)] +pub enum MockCallData { + /// Matches any call data. + Any, + /// Matches the specified serialized call data. + Values: Span, +} + +impl MockCallDataSerde of Serde { + fn deserialize(ref serialized: Span) -> Option { + let value: Option>> = Serde::deserialize(ref serialized); + + match value { + Option::None => Option::None, + Option::Some(call_data) => match call_data { + Option::None => Option::Some(MockCallData::Any), + Option::Some(data) => Option::Some(MockCallData::Values(data)), + }, + } + } + + fn serialize(self: @MockCallData, ref output: Array) { + match self { + MockCallData::Any => Option::>::None.serialize(ref output), + MockCallData::Values(data) => Option::Some(*data).serialize(ref output), + } + } +} + pub fn test_selector() -> felt252 { // Result of selector!("TEST_CONTRACT_SELECTOR") since `selector!` macro requires dependency on // `starknet`. @@ -43,13 +73,17 @@ pub fn test_address() -> ContractAddress { /// - `ret_data` - data to return by the function `function_selector` /// - `n_times` - number of calls to mock the function for pub fn mock_call, impl TDestruct: Destruct>( - contract_address: ContractAddress, function_selector: felt252, ret_data: T, n_times: u32 + contract_address: ContractAddress, + function_selector: felt252, + call_data: MockCallData, + ret_data: T, + n_times: u32 ) { assert!(n_times > 0, "cannot mock_call 0 times, n_times argument must be greater than 0"); let contract_address_felt: felt252 = contract_address.into(); let mut inputs = array![contract_address_felt, function_selector]; - + call_data.serialize(ref inputs); CheatSpan::TargetCalls(n_times).serialize(ref inputs); let mut ret_data_arr = ArrayTrait::new(); @@ -66,13 +100,17 @@ pub fn mock_call, impl TDestruct: Destruct /// - `contract_address` - targeted contracts' address /// - `function_selector` - hashed name of the target function (can be obtained with `selector!` /// macro) +/// - `call_data` - matching call data /// - `ret_data` - data to be returned by the function pub fn start_mock_call, impl TDestruct: Destruct>( - contract_address: ContractAddress, function_selector: felt252, ret_data: T + contract_address: ContractAddress, + function_selector: felt252, + call_data: MockCallData, + ret_data: T ) { let contract_address_felt: felt252 = contract_address.into(); let mut inputs = array![contract_address_felt, function_selector]; - + call_data.serialize(ref inputs); CheatSpan::Indefinite.serialize(ref inputs); let mut ret_data_arr = ArrayTrait::new(); @@ -87,12 +125,16 @@ pub fn start_mock_call, impl TDestruct: De /// address. /// - `contract_address` - targeted contracts' address /// - `function_selector` - hashed name of the target function (can be obtained with `selector!` +/// - `call_data` - matching call data /// macro) -pub fn stop_mock_call(contract_address: ContractAddress, function_selector: felt252) { +pub fn stop_mock_call( + contract_address: ContractAddress, function_selector: felt252, call_data: MockCallData +) { let contract_address_felt: felt252 = contract_address.into(); - execute_cheatcode_and_deserialize::< - 'stop_mock_call', () - >(array![contract_address_felt, function_selector].span()); + let mut inputs = array![contract_address_felt, function_selector]; + call_data.serialize(ref inputs); + + execute_cheatcode_and_deserialize::<'stop_mock_call', ()>(inputs.span()); } #[derive(Drop, Serde, PartialEq, Debug)] diff --git a/snforge_std/src/lib.cairo b/snforge_std/src/lib.cairo index e62e9e6006..321369fd06 100644 --- a/snforge_std/src/lib.cairo +++ b/snforge_std/src/lib.cairo @@ -30,6 +30,7 @@ pub use cheatcodes::CheatSpan; pub use cheatcodes::ReplaceBytecodeError; pub use cheatcodes::test_address; pub use cheatcodes::test_selector; +pub use cheatcodes::MockCallData; pub use cheatcodes::mock_call; pub use cheatcodes::start_mock_call; pub use cheatcodes::stop_mock_call; From f21e51056d72ff15bf8f3c4fc500b04a460dc8c6 Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Fri, 31 Jan 2025 14:50:19 +0100 Subject: [PATCH 02/18] feat(cheatnet): add "dynamic" return value fir mock_call --- Cargo.lock | 1 + crates/cheatnet/Cargo.toml | 1 + .../execution/entry_point.rs | 19 ++++++++++++--- .../cheatcodes/mock_call.rs | 23 ++++++++++++++----- .../forge_runtime_extension/mod.rs | 15 ++++++++---- crates/cheatnet/src/state.rs | 2 +- crates/cheatnet/tests/cheatcodes/mock_call.rs | 17 ++++++++++++-- 7 files changed, 61 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c6bfaed0d9..59dbd0b23e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1285,6 +1285,7 @@ dependencies = [ "serde_json", "shared", "starknet", + "starknet-crypto 0.7.1", "starknet-types-core", "starknet_api", "tempfile", diff --git a/crates/cheatnet/Cargo.toml b/crates/cheatnet/Cargo.toml index 66c651b585..52663d9bd0 100644 --- a/crates/cheatnet/Cargo.toml +++ b/crates/cheatnet/Cargo.toml @@ -13,6 +13,7 @@ bimap.workspace = true camino.workspace = true starknet_api.workspace = true starknet-types-core.workspace = true +starknet-crypto.workspace = true cairo-lang-casm.workspace = true cairo-lang-runner.workspace = true cairo-lang-utils.workspace = true diff --git a/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs b/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs index 96a30cab2c..9b62aeacc6 100644 --- a/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs +++ b/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs @@ -18,11 +18,13 @@ use blockifier::{ state::state_api::State, }; use cairo_vm::vm::runners::cairo_runner::{CairoRunner, ExecutionResources}; +use num_traits::Zero; use starknet_api::{ core::ClassHash, deprecated_contract_class::EntryPointType, transaction::{Calldata, TransactionVersion}, }; +use starknet_crypto::poseidon_hash_many; use std::collections::HashSet; use std::rc::Rc; use blockifier::execution::deprecated_syscalls::hint_processor::SyscallCounter; @@ -268,11 +270,22 @@ fn get_mocked_function_cheat_status<'a>( if call.call_type == CallType::Delegate { return None; } - - cheatnet_state + match cheatnet_state .mocked_functions .get_mut(&call.storage_address) - .and_then(|contract_functions| contract_functions.get_mut(&call.entry_point_selector)) + { + None => None, + Some(contract_functions) => { + let call_data_hash = poseidon_hash_many(call.calldata.0.iter()); + let key = (call.entry_point_selector, call_data_hash); + if contract_functions.contains_key(&key) { + contract_functions.get_mut(&key) + } else { + let key_zero = (call.entry_point_selector, Felt::zero()); + contract_functions.get_mut(&key_zero) + } + } + } } fn mocked_call_info(call: CallEntryPoint, ret_data: Vec) -> CallInfo { diff --git a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/mock_call.rs b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/mock_call.rs index 5aeb0aff77..81a4e1c92e 100644 --- a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/mock_call.rs +++ b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/mock_call.rs @@ -1,6 +1,8 @@ use crate::state::{CheatSpan, CheatStatus}; use crate::CheatnetState; +use num_traits::Zero; use starknet_api::core::{ContractAddress, EntryPointSelector}; +use starknet_crypto::poseidon_hash_many; use starknet_types_core::felt::Felt; use std::collections::hash_map::Entry; @@ -9,26 +11,30 @@ impl CheatnetState { &mut self, contract_address: ContractAddress, function_selector: EntryPointSelector, + call_data: Option>, ret_data: &[Felt], span: CheatSpan, ) { let contract_mocked_functions = self.mocked_functions.entry(contract_address).or_default(); - - contract_mocked_functions.insert( - function_selector, - CheatStatus::Cheated(ret_data.to_vec(), span), - ); + let call_data_hash = match call_data { + Some(data) => poseidon_hash_many(data.iter()), + None => Felt::zero(), + }; + let key = (function_selector, call_data_hash); + contract_mocked_functions.insert(key, CheatStatus::Cheated(ret_data.to_vec(), span)); } pub fn start_mock_call( &mut self, contract_address: ContractAddress, function_selector: EntryPointSelector, + call_data: Option>, ret_data: &[Felt], ) { self.mock_call( contract_address, function_selector, + call_data, ret_data, CheatSpan::Indefinite, ); @@ -38,10 +44,15 @@ impl CheatnetState { &mut self, contract_address: ContractAddress, function_selector: EntryPointSelector, + call_data: Option>, ) { if let Entry::Occupied(mut e) = self.mocked_functions.entry(contract_address) { let contract_mocked_functions = e.get_mut(); - contract_mocked_functions.remove(&function_selector); + let call_data_hash = match call_data { + Some(data) => poseidon_hash_many(data.iter()), + None => Felt::zero(), + }; + contract_mocked_functions.remove(&(function_selector, call_data_hash)); } } } diff --git a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs index 2938c7bf97..c7b76bbb7e 100644 --- a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs +++ b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs @@ -87,26 +87,31 @@ impl<'a> ExtensionLogic for ForgeExtension<'a> { "mock_call" => { let contract_address = input_reader.read()?; let function_selector = input_reader.read()?; + let call_data = input_reader.read()?; let span = input_reader.read()?; - let ret_data: Vec<_> = input_reader.read()?; - extended_runtime .extended_runtime .extension .cheatnet_state - .mock_call(contract_address, function_selector, &ret_data, span); + .mock_call( + contract_address, + function_selector, + call_data, + &ret_data, + span, + ); Ok(CheatcodeHandlingResult::from_serializable(())) } "stop_mock_call" => { let contract_address = input_reader.read()?; let function_selector = input_reader.read()?; - + let call_data = input_reader.read()?; extended_runtime .extended_runtime .extension .cheatnet_state - .stop_mock_call(contract_address, function_selector); + .stop_mock_call(contract_address, function_selector, call_data); Ok(CheatcodeHandlingResult::from_serializable(())) } "replace_bytecode" => { diff --git a/crates/cheatnet/src/state.rs b/crates/cheatnet/src/state.rs index 5b0fbf7189..8905868b4f 100644 --- a/crates/cheatnet/src/state.rs +++ b/crates/cheatnet/src/state.rs @@ -331,7 +331,7 @@ pub struct CheatnetState { pub global_cheated_execution_info: ExecutionInfoMock, pub mocked_functions: - HashMap>>>, + HashMap>>>, pub replaced_bytecode_contracts: HashMap, pub detected_events: Vec, pub detected_messages_to_l1: Vec, diff --git a/crates/cheatnet/tests/cheatcodes/mock_call.rs b/crates/cheatnet/tests/cheatcodes/mock_call.rs index e7cf8653ac..4b4bd844f0 100644 --- a/crates/cheatnet/tests/cheatcodes/mock_call.rs +++ b/crates/cheatnet/tests/cheatcodes/mock_call.rs @@ -68,6 +68,7 @@ fn mock_call_simple() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing"), + None, &ret_data, ); @@ -100,6 +101,7 @@ fn mock_call_stop() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing"), + None, &ret_data, ); @@ -168,10 +170,10 @@ fn mock_call_double() { let selector = felt_selector_from_name("get_thing"); let ret_data = [Felt::from(123)]; - cheatnet_state.start_mock_call(contract_address, selector, &ret_data); + cheatnet_state.start_mock_call(contract_address, selector, None, &ret_data); let ret_data = [Felt::from(999)]; - cheatnet_state.start_mock_call(contract_address, selector, &ret_data); + cheatnet_state.start_mock_call(contract_address, selector, None, &ret_data); let output = call_contract( &mut cached_state, @@ -214,6 +216,7 @@ fn mock_call_double_call() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing"), + None, &ret_data, ); @@ -255,6 +258,7 @@ fn mock_call_proxy() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing"), + None, &ret_data, ); @@ -303,6 +307,7 @@ fn mock_call_proxy_with_other_syscall() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing"), + None, &ret_data, ); @@ -352,6 +357,7 @@ fn mock_call_inner_call_no_effect() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing"), + None, &ret_data, ); @@ -407,6 +413,7 @@ fn mock_call_library_call_no_effect() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_constant_thing"), + None, &ret_data, ); @@ -440,6 +447,7 @@ fn mock_call_before_deployment() { cheatnet_state.start_mock_call( precalculated_address, felt_selector_from_name("get_thing"), + None, &ret_data, ); @@ -482,6 +490,7 @@ fn mock_call_not_implemented() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing_not_implemented"), + None, &ret_data, ); @@ -512,6 +521,7 @@ fn mock_call_in_constructor() { cheatnet_state.start_mock_call( balance_contract_address, felt_selector_from_name("get_balance"), + None, &ret_data, ); @@ -559,12 +569,14 @@ fn mock_call_two_methods() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing"), + None, &ret_data, ); cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_constant_thing"), + None, &ret_data, ); @@ -602,6 +614,7 @@ fn mock_call_nonexisting_contract() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing"), + None, &ret_data, ); From 3c378144497f27754afa52aaf11ac9e2b56560d1 Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Fri, 31 Jan 2025 14:50:47 +0100 Subject: [PATCH 03/18] forge: update test to support new mock_call cheatcode --- crates/forge/tests/integration/cheat_fork.rs | 6 ++-- crates/forge/tests/integration/mock_call.rs | 32 ++++++++++++-------- crates/forge/tests/integration/test_state.rs | 4 +-- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/crates/forge/tests/integration/cheat_fork.rs b/crates/forge/tests/integration/cheat_fork.rs index a1ba19c3de..6e0ad3d679 100644 --- a/crates/forge/tests/integration/cheat_fork.rs +++ b/crates/forge/tests/integration/cheat_fork.rs @@ -155,7 +155,7 @@ fn mock_call_cairo0_contract() { let test = test_case!(formatdoc!( r#" use starknet::{{contract_address_const}}; - use snforge_std::{{start_mock_call, stop_mock_call}}; + use snforge_std::{{start_mock_call, stop_mock_call, MockCallData}}; #[starknet::interface] trait IERC20 {{ @@ -173,11 +173,11 @@ fn mock_call_cairo0_contract() { assert(eth_dispatcher.name() == 'Ether', 'invalid name'); - start_mock_call(eth_dispatcher.contract_address, selector!("name"), 'NotEther'); + start_mock_call(eth_dispatcher.contract_address, selector!("name"), MockCallData::Any, 'NotEther'); assert(eth_dispatcher.name() == 'NotEther', 'invalid mocked name'); - stop_mock_call(eth_dispatcher.contract_address, selector!("name")); + stop_mock_call(eth_dispatcher.contract_address, selector!("name"), MockCallData::Any); assert(eth_dispatcher.name() == 'Ether', 'invalid name after mock'); }} diff --git a/crates/forge/tests/integration/mock_call.rs b/crates/forge/tests/integration/mock_call.rs index 6205e110fc..57e287294b 100644 --- a/crates/forge/tests/integration/mock_call.rs +++ b/crates/forge/tests/integration/mock_call.rs @@ -10,7 +10,7 @@ fn mock_call_simple() { indoc!( r#" use result::ResultTrait; - use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call, stop_mock_call }; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call, stop_mock_call, MockCallData }; #[starknet::interface] trait IMockChecker { @@ -26,13 +26,19 @@ fn mock_call_simple() { let dispatcher = IMockCheckerDispatcher { contract_address }; - let mock_ret_data = 421; + let specific_mock_ret_data = 421; + let default_mock_ret_data = 404; + let expected_calldata = MockCallData::Values([].span()); + start_mock_call(contract_address, selector!("get_thing"), expected_calldata, specific_mock_ret_data); + start_mock_call(contract_address, selector!("get_thing"), MockCallData::Any, default_mock_ret_data); + let thing = dispatcher.get_thing(); + assert(thing == specific_mock_ret_data, 'Incorrect thing'); - start_mock_call(contract_address, selector!("get_thing"), mock_ret_data); + stop_mock_call(contract_address, selector!("get_thing"), expected_calldata); let thing = dispatcher.get_thing(); - assert(thing == 421, 'Incorrect thing'); + assert(thing == default_mock_ret_data, 'Incorrect thing'); - stop_mock_call(contract_address, selector!("get_thing")); + stop_mock_call(contract_address, selector!("get_thing"), MockCallData::Any); let thing = dispatcher.get_thing(); assert(thing == 420, 'Incorrect thing'); } @@ -45,12 +51,12 @@ fn mock_call_simple() { let (contract_address, _) = contract.deploy(@calldata).unwrap(); let mock_ret_data = 421; - start_mock_call(contract_address, selector!("get_thing"), mock_ret_data); + start_mock_call(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data); let dispatcher = IMockCheckerDispatcher { contract_address }; let thing = dispatcher.get_thing(); - assert(thing == 421, 'Incorrect thing'); + assert(thing == 421, 'Incorrect thing all catch'); } "# ), @@ -73,7 +79,7 @@ fn mock_call_complex_types() { use result::ResultTrait; use array::ArrayTrait; use serde::Serde; - use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call }; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call, MockCallData }; #[starknet::interface] trait IMockChecker { @@ -97,7 +103,7 @@ fn mock_call_complex_types() { let dispatcher = IMockCheckerDispatcher { contract_address }; let mock_ret_data = StructThing {item_one: 412, item_two: 421}; - start_mock_call(contract_address, selector!("get_struct_thing"), mock_ret_data); + start_mock_call(contract_address, selector!("get_struct_thing"), MockCallData::Any, mock_ret_data); let thing: StructThing = dispatcher.get_struct_thing(); @@ -115,7 +121,7 @@ fn mock_call_complex_types() { let dispatcher = IMockCheckerDispatcher { contract_address }; let mock_ret_data = array![ StructThing {item_one: 112, item_two: 121}, StructThing {item_one: 412, item_two: 421} ]; - start_mock_call(contract_address, selector!("get_arr_thing"), mock_ret_data); + start_mock_call(contract_address, selector!("get_arr_thing"), MockCallData::Any, mock_ret_data); let things: Array = dispatcher.get_arr_thing(); @@ -146,7 +152,7 @@ fn mock_calls() { indoc!( r#" use result::ResultTrait; - use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call, start_mock_call, stop_mock_call }; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call, MockCallData }; #[starknet::interface] trait IMockChecker { @@ -164,7 +170,7 @@ fn mock_calls() { let mock_ret_data = 421; - mock_call(contract_address, selector!("get_thing"), mock_ret_data, 1); + mock_call(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data, 1); let thing = dispatcher.get_thing(); assert_eq!(thing, 421); @@ -184,7 +190,7 @@ fn mock_calls() { let mock_ret_data = 421; - mock_call(contract_address, selector!("get_thing"), mock_ret_data, 2); + mock_call(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data, 2); let thing = dispatcher.get_thing(); assert_eq!(thing, 421); diff --git a/crates/forge/tests/integration/test_state.rs b/crates/forge/tests/integration/test_state.rs index 00a8d7478b..4f4b08656a 100644 --- a/crates/forge/tests/integration/test_state.rs +++ b/crates/forge/tests/integration/test_state.rs @@ -678,7 +678,7 @@ fn inconsistent_syscall_pointers() { r#" use starknet::ContractAddress; use starknet::info::get_block_number; - use snforge_std::start_mock_call; + use snforge_std::{start_mock_call, MockCallData}; #[starknet::interface] trait IContract { @@ -689,7 +689,7 @@ fn inconsistent_syscall_pointers() { fn inconsistent_syscall_pointers() { // verifies if SyscallHandler.syscal_ptr is incremented correctly when calling a contract let address = 'address'.try_into().unwrap(); - start_mock_call(address, selector!("get_value"), 55); + start_mock_call(address, selector!("get_value"), MockCallData::Any, 55); let contract = IContractDispatcher { contract_address: address }; contract.get_value(address); get_block_number(); From d1921eaf93811d07913a213bbcdeff07d54b5fee Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Fri, 31 Jan 2025 16:07:07 +0100 Subject: [PATCH 04/18] fix mock_call test --- crates/cheatnet/tests/cheatcodes/mock_call.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/crates/cheatnet/tests/cheatcodes/mock_call.rs b/crates/cheatnet/tests/cheatcodes/mock_call.rs index 4b4bd844f0..4e57fa0a66 100644 --- a/crates/cheatnet/tests/cheatcodes/mock_call.rs +++ b/crates/cheatnet/tests/cheatcodes/mock_call.rs @@ -38,6 +38,7 @@ impl MockCallTrait for TestEnvironment { self.cheatnet_state.mock_call( *contract_address, function_selector.into_(), + None, &ret_data, span, ); @@ -46,7 +47,7 @@ impl MockCallTrait for TestEnvironment { fn stop_mock_call(&mut self, contract_address: &ContractAddress, function_name: &str) { let function_selector = get_selector_from_name(function_name).unwrap(); self.cheatnet_state - .stop_mock_call(*contract_address, function_selector.into_()); + .stop_mock_call(*contract_address, function_selector.into_(), None); } } @@ -115,7 +116,7 @@ fn mock_call_stop() { assert_success(output, &ret_data); - cheatnet_state.stop_mock_call(contract_address, felt_selector_from_name("get_thing")); + cheatnet_state.stop_mock_call(contract_address, felt_selector_from_name("get_thing"), None); let output = call_contract( &mut cached_state, @@ -142,7 +143,7 @@ fn mock_call_stop_no_start() { let selector = felt_selector_from_name("get_thing"); - cheatnet_state.stop_mock_call(contract_address, felt_selector_from_name("get_thing")); + cheatnet_state.stop_mock_call(contract_address, felt_selector_from_name("get_thing"), None); let output = call_contract( &mut cached_state, @@ -185,7 +186,7 @@ fn mock_call_double() { assert_success(output, &ret_data); - cheatnet_state.stop_mock_call(contract_address, selector); + cheatnet_state.stop_mock_call(contract_address, selector, None); let output = call_contract( &mut cached_state, From 9594b5aeff10cb41999e9298e0c370a4dffbfa53 Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Fri, 31 Jan 2025 16:28:15 +0100 Subject: [PATCH 05/18] cheatnet: fixup clippy --- crates/cheatnet/src/state.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/cheatnet/src/state.rs b/crates/cheatnet/src/state.rs index 8905868b4f..7dfb5d00f9 100644 --- a/crates/cheatnet/src/state.rs +++ b/crates/cheatnet/src/state.rs @@ -326,12 +326,13 @@ pub struct EncounteredError { pub class_hash: ClassHash, } +type MockedFunctionKey = (EntryPointSelector, Felt); pub struct CheatnetState { pub cheated_execution_info_contracts: HashMap, pub global_cheated_execution_info: ExecutionInfoMock, pub mocked_functions: - HashMap>>>, + HashMap>>>, pub replaced_bytecode_contracts: HashMap, pub detected_events: Vec, pub detected_messages_to_l1: Vec, From c097e3c63a53034c2ee6920780d0ef31955b428f Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Tue, 4 Feb 2025 00:03:26 +0100 Subject: [PATCH 06/18] snforge_std: revert breaking changes on `mock_call` Add `mock_call_when`, `start_mock_call_when` and `stop_mock_call_when` --- crates/forge/tests/integration/cheat_fork.rs | 6 +- crates/forge/tests/integration/mock_call.rs | 213 +++++++++++++++++-- crates/forge/tests/integration/test_state.rs | 4 +- snforge_std/src/cheatcodes.cairo | 49 ++++- snforge_std/src/lib.cairo | 3 + 5 files changed, 247 insertions(+), 28 deletions(-) diff --git a/crates/forge/tests/integration/cheat_fork.rs b/crates/forge/tests/integration/cheat_fork.rs index 6e0ad3d679..a1ba19c3de 100644 --- a/crates/forge/tests/integration/cheat_fork.rs +++ b/crates/forge/tests/integration/cheat_fork.rs @@ -155,7 +155,7 @@ fn mock_call_cairo0_contract() { let test = test_case!(formatdoc!( r#" use starknet::{{contract_address_const}}; - use snforge_std::{{start_mock_call, stop_mock_call, MockCallData}}; + use snforge_std::{{start_mock_call, stop_mock_call}}; #[starknet::interface] trait IERC20 {{ @@ -173,11 +173,11 @@ fn mock_call_cairo0_contract() { assert(eth_dispatcher.name() == 'Ether', 'invalid name'); - start_mock_call(eth_dispatcher.contract_address, selector!("name"), MockCallData::Any, 'NotEther'); + start_mock_call(eth_dispatcher.contract_address, selector!("name"), 'NotEther'); assert(eth_dispatcher.name() == 'NotEther', 'invalid mocked name'); - stop_mock_call(eth_dispatcher.contract_address, selector!("name"), MockCallData::Any); + stop_mock_call(eth_dispatcher.contract_address, selector!("name")); assert(eth_dispatcher.name() == 'Ether', 'invalid name after mock'); }} diff --git a/crates/forge/tests/integration/mock_call.rs b/crates/forge/tests/integration/mock_call.rs index 57e287294b..956be65f03 100644 --- a/crates/forge/tests/integration/mock_call.rs +++ b/crates/forge/tests/integration/mock_call.rs @@ -10,7 +10,7 @@ fn mock_call_simple() { indoc!( r#" use result::ResultTrait; - use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call, stop_mock_call, MockCallData }; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call, stop_mock_call }; #[starknet::interface] trait IMockChecker { @@ -26,19 +26,13 @@ fn mock_call_simple() { let dispatcher = IMockCheckerDispatcher { contract_address }; - let specific_mock_ret_data = 421; - let default_mock_ret_data = 404; - let expected_calldata = MockCallData::Values([].span()); - start_mock_call(contract_address, selector!("get_thing"), expected_calldata, specific_mock_ret_data); - start_mock_call(contract_address, selector!("get_thing"), MockCallData::Any, default_mock_ret_data); - let thing = dispatcher.get_thing(); - assert(thing == specific_mock_ret_data, 'Incorrect thing'); + let mock_ret_data = 421; - stop_mock_call(contract_address, selector!("get_thing"), expected_calldata); + start_mock_call(contract_address, selector!("get_thing"), mock_ret_data); let thing = dispatcher.get_thing(); - assert(thing == default_mock_ret_data, 'Incorrect thing'); + assert(thing == 421, 'Incorrect thing'); - stop_mock_call(contract_address, selector!("get_thing"), MockCallData::Any); + stop_mock_call(contract_address, selector!("get_thing")); let thing = dispatcher.get_thing(); assert(thing == 420, 'Incorrect thing'); } @@ -51,12 +45,12 @@ fn mock_call_simple() { let (contract_address, _) = contract.deploy(@calldata).unwrap(); let mock_ret_data = 421; - start_mock_call(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data); + start_mock_call(contract_address, selector!("get_thing"), mock_ret_data); let dispatcher = IMockCheckerDispatcher { contract_address }; let thing = dispatcher.get_thing(); - assert(thing == 421, 'Incorrect thing all catch'); + assert(thing == 421, 'Incorrect thing'); } "# ), @@ -79,7 +73,7 @@ fn mock_call_complex_types() { use result::ResultTrait; use array::ArrayTrait; use serde::Serde; - use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call, MockCallData }; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call }; #[starknet::interface] trait IMockChecker { @@ -103,7 +97,7 @@ fn mock_call_complex_types() { let dispatcher = IMockCheckerDispatcher { contract_address }; let mock_ret_data = StructThing {item_one: 412, item_two: 421}; - start_mock_call(contract_address, selector!("get_struct_thing"), MockCallData::Any, mock_ret_data); + start_mock_call(contract_address, selector!("get_struct_thing"), mock_ret_data); let thing: StructThing = dispatcher.get_struct_thing(); @@ -121,7 +115,7 @@ fn mock_call_complex_types() { let dispatcher = IMockCheckerDispatcher { contract_address }; let mock_ret_data = array![ StructThing {item_one: 112, item_two: 121}, StructThing {item_one: 412, item_two: 421} ]; - start_mock_call(contract_address, selector!("get_arr_thing"), MockCallData::Any, mock_ret_data); + start_mock_call(contract_address, selector!("get_arr_thing"), mock_ret_data); let things: Array = dispatcher.get_arr_thing(); @@ -152,7 +146,7 @@ fn mock_calls() { indoc!( r#" use result::ResultTrait; - use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call, MockCallData }; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call, start_mock_call, stop_mock_call }; #[starknet::interface] trait IMockChecker { @@ -170,7 +164,7 @@ fn mock_calls() { let mock_ret_data = 421; - mock_call(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data, 1); + mock_call(contract_address, selector!("get_thing"), mock_ret_data, 1); let thing = dispatcher.get_thing(); assert_eq!(thing, 421); @@ -190,7 +184,7 @@ fn mock_calls() { let mock_ret_data = 421; - mock_call(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data, 2); + mock_call(contract_address, selector!("get_thing"), mock_ret_data, 2); let thing = dispatcher.get_thing(); assert_eq!(thing, 421); @@ -213,3 +207,184 @@ fn mock_calls() { let result = run_test_case(&test); assert_passed(&result); } + +#[test] +fn mock_call_when_simple() { + let test = test_case!( + indoc!( + r#" + use result::ResultTrait; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call_when, stop_mock_call_when, MockCallData }; + + #[starknet::interface] + trait IMockChecker { + fn get_thing(ref self: TContractState) -> felt252; + } + + #[test] + fn mock_call_when_simple() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let specific_mock_ret_data = 421; + let default_mock_ret_data = 404; + let expected_calldata = MockCallData::Values([].span()); + + start_mock_call_when(contract_address, selector!("get_thing"), expected_calldata, specific_mock_ret_data); + start_mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, default_mock_ret_data); + let thing = dispatcher.get_thing(); + assert(thing == specific_mock_ret_data, 'Incorrect thing'); + + stop_mock_call_when(contract_address, selector!("get_thing"), expected_calldata); + let thing = dispatcher.get_thing(); + assert(thing == default_mock_ret_data, 'Incorrect thing'); + + stop_mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any); + let thing = dispatcher.get_thing(); + assert(thing == 420, 'Incorrect thing'); + } + + #[test] + fn mock_call_when_simple_before_dispatcher_created() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let specific_mock_ret_data = 421; + let default_mock_ret_data = 404; + let expected_calldata = MockCallData::Values([].span()); + + start_mock_call_when(contract_address, selector!("get_thing"), expected_calldata, specific_mock_ret_data); + start_mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, default_mock_ret_data); + let dispatcher = IMockCheckerDispatcher { contract_address }; + let thing = dispatcher.get_thing(); + + assert(thing == specific_mock_ret_data, 'Incorrect thing'); + + stop_mock_call_when(contract_address, selector!("get_thing"), expected_calldata); + let thing = dispatcher.get_thing(); + assert(thing == default_mock_ret_data, 'Incorrect thing'); + + stop_mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any); + let thing = dispatcher.get_thing(); + assert(thing == 420, 'Incorrect thing'); + } + "# + ), + Contract::from_code_path( + "MockChecker".to_string(), + Path::new("tests/data/contracts/mock_checker.cairo"), + ) + .unwrap() + ); + + let result = run_test_case(&test); + assert_passed(&result); +} + +#[test] +fn mock_call_when_complex_types() { + let test = test_case!( + indoc!( + r#" + use result::ResultTrait; + use array::ArrayTrait; + use serde::Serde; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call_when, stop_mock_call_when, MockCallData }; + + #[starknet::interface] + trait IMockChecker { + fn get_struct_thing(ref self: TContractState) -> StructThing; + fn get_arr_thing(ref self: TContractState) -> Array; + } + + #[derive(Serde, Drop)] + struct StructThing { + item_one: felt252, + item_two: felt252, + } + + #[test] + fn start_mock_call_when_return_struct() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let default_mock_ret_data = StructThing {item_one: 412, item_two: 421}; + let specific_mock_ret_data = StructThing {item_one: 404, item_two: 401}; + let expected_calldata = MockCallData::Values([].span()); + + start_mock_call_when(contract_address, selector!("get_struct_thing"), MockCallData::Any, default_mock_ret_data); + start_mock_call_when(contract_address, selector!("get_struct_thing"), expected_calldata, specific_mock_ret_data); + + let thing: StructThing = dispatcher.get_struct_thing(); + + assert(thing.item_one == 404, 'thing.item_one'); + assert(thing.item_two == 401, 'thing.item_two'); + + stop_mock_call_when(contract_address, selector!("get_struct_thing"), expected_calldata); + let thing: StructThing = dispatcher.get_struct_thing(); + + assert(thing.item_one == 412, 'thing.item_one'); + assert(thing.item_two == 421, 'thing.item_two'); + } + + #[test] + fn start_mock_call_when_return_arr() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let default_mock_ret_data = array![ StructThing {item_one: 112, item_two: 121}, StructThing {item_one: 412, item_two: 421} ]; + let specific_mock_ret_data = array![ StructThing {item_one: 212, item_two: 221}, StructThing {item_one: 512, item_two: 521} ]; + + let expected_calldata = MockCallData::Values([].span()); + + start_mock_call_when(contract_address, selector!("get_arr_thing"), MockCallData::Any, default_mock_ret_data); + start_mock_call_when(contract_address, selector!("get_arr_thing"), expected_calldata, specific_mock_ret_data); + + let things: Array = dispatcher.get_arr_thing(); + + let thing = things.at(0); + assert(*thing.item_one == 212, 'thing1.item_one 1'); + assert(*thing.item_two == 221, 'thing1.item_two'); + + let thing = things.at(1); + assert(*thing.item_one == 512, 'thing2.item_one 2'); + assert(*thing.item_two == 521, 'thing2.item_two'); + + stop_mock_call_when(contract_address, selector!("get_arr_thing"), expected_calldata); + + let things: Array = dispatcher.get_arr_thing(); + + let thing = things.at(0); + assert(*thing.item_one == 112, 'thing1.item_one 3'); + assert(*thing.item_two == 121, 'thing1.item_two'); + + let thing = things.at(1); + assert(*thing.item_one == 412, 'thing2.item_one 4'); + assert(*thing.item_two == 421, 'thing2.item_two'); + } + "# + ), + Contract::from_code_path( + "MockChecker".to_string(), + Path::new("tests/data/contracts/mock_checker.cairo"), + ) + .unwrap() + ); + + let result = run_test_case(&test); + assert_passed(&result); +} diff --git a/crates/forge/tests/integration/test_state.rs b/crates/forge/tests/integration/test_state.rs index 4f4b08656a..00a8d7478b 100644 --- a/crates/forge/tests/integration/test_state.rs +++ b/crates/forge/tests/integration/test_state.rs @@ -678,7 +678,7 @@ fn inconsistent_syscall_pointers() { r#" use starknet::ContractAddress; use starknet::info::get_block_number; - use snforge_std::{start_mock_call, MockCallData}; + use snforge_std::start_mock_call; #[starknet::interface] trait IContract { @@ -689,7 +689,7 @@ fn inconsistent_syscall_pointers() { fn inconsistent_syscall_pointers() { // verifies if SyscallHandler.syscal_ptr is incremented correctly when calling a contract let address = 'address'.try_into().unwrap(); - start_mock_call(address, selector!("get_value"), MockCallData::Any, 55); + start_mock_call(address, selector!("get_value"), 55); let contract = IContractDispatcher { contract_address: address }; contract.get_value(address); get_block_number(); diff --git a/snforge_std/src/cheatcodes.cairo b/snforge_std/src/cheatcodes.cairo index 277df84b46..d2ed207e41 100644 --- a/snforge_std/src/cheatcodes.cairo +++ b/snforge_std/src/cheatcodes.cairo @@ -60,6 +60,7 @@ pub fn test_address() -> ContractAddress { contract_address_const::<469394814521890341860918960550914>() } + /// Mocks contract call to a `function_selector` of a contract at the given address, for `n_times` /// first calls that are made to the contract. /// A call to function `function_selector` will return data provided in `ret_data` argument. @@ -73,6 +74,46 @@ pub fn test_address() -> ContractAddress { /// - `ret_data` - data to return by the function `function_selector` /// - `n_times` - number of calls to mock the function for pub fn mock_call, impl TDestruct: Destruct>( + contract_address: ContractAddress, function_selector: felt252, ret_data: T, n_times: u32 +) { + mock_call_when(contract_address, function_selector, MockCallData::Any, ret_data, n_times) +} + +/// Mocks contract call to a function of a contract at the given address, indefinitely. +/// See `mock_call` for comprehensive definition of how it can be used. +/// - `contract_address` - targeted contracts' address +/// - `function_selector` - hashed name of the target function (can be obtained with `selector!` +/// macro) +/// - `ret_data` - data to be returned by the function +pub fn start_mock_call, impl TDestruct: Destruct>( + contract_address: ContractAddress, function_selector: felt252, ret_data: T +) { + start_mock_call_when(contract_address, function_selector, MockCallData::Any, ret_data) +} + +/// Cancels the `mock_call` / `start_mock_call` for the function with given name and contract +/// address. +/// - `contract_address` - targeted contracts' address +/// - `function_selector` - hashed name of the target function (can be obtained with `selector!` +/// macro) +pub fn stop_mock_call(contract_address: ContractAddress, function_selector: felt252,) { + stop_mock_call_when(contract_address, function_selector, MockCallData::Any) +} + +/// Mocks contract call to a `function_selector` of a contract at the given address, for `n_times` +/// first calls that are made to the contract. +/// A call to function `function_selector` will return data provided in `ret_data` argument. +/// An address with no contract can be mocked as well. +/// An entrypoint that is not present on the deployed contract is also possible to mock. +/// Note that the function is not meant for mocking internal calls - it works only for contract +/// entry points. +/// - `contract_address` - target contract address +/// - `function_selector` - hashed name of the target function (can be obtained with `selector!` +/// macro) +/// - `call_data` - matching call data +/// - `ret_data` - data to return by the function `function_selector` +/// - `n_times` - number of calls to mock the function for +pub fn mock_call_when, impl TDestruct: Destruct>( contract_address: ContractAddress, function_selector: felt252, call_data: MockCallData, @@ -102,7 +143,7 @@ pub fn mock_call, impl TDestruct: Destruct /// macro) /// - `call_data` - matching call data /// - `ret_data` - data to be returned by the function -pub fn start_mock_call, impl TDestruct: Destruct>( +pub fn start_mock_call_when, impl TDestruct: Destruct>( contract_address: ContractAddress, function_selector: felt252, call_data: MockCallData, @@ -121,13 +162,13 @@ pub fn start_mock_call, impl TDestruct: De execute_cheatcode_and_deserialize::<'mock_call', ()>(inputs.span()); } -/// Cancels the `mock_call` / `start_mock_call` for the function with given name and contract -/// address. +/// Cancels the `mock_call_when` / `start_mock_call_when` for the function with given name and +/// contract address. /// - `contract_address` - targeted contracts' address /// - `function_selector` - hashed name of the target function (can be obtained with `selector!` /// - `call_data` - matching call data /// macro) -pub fn stop_mock_call( +pub fn stop_mock_call_when( contract_address: ContractAddress, function_selector: felt252, call_data: MockCallData ) { let contract_address_felt: felt252 = contract_address.into(); diff --git a/snforge_std/src/lib.cairo b/snforge_std/src/lib.cairo index 321369fd06..ab2532bcab 100644 --- a/snforge_std/src/lib.cairo +++ b/snforge_std/src/lib.cairo @@ -34,6 +34,9 @@ pub use cheatcodes::MockCallData; pub use cheatcodes::mock_call; pub use cheatcodes::start_mock_call; pub use cheatcodes::stop_mock_call; +pub use cheatcodes::mock_call_when; +pub use cheatcodes::start_mock_call_when; +pub use cheatcodes::stop_mock_call_when; pub use cheatcodes::replace_bytecode; pub use cheatcodes::execution_info::caller_address::cheat_caller_address; From e70b4b19930b2c6e7f7a1830868483556002a5d4 Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Tue, 4 Feb 2025 00:56:23 +0100 Subject: [PATCH 07/18] fix(cheatnet): default to any calldata entry when specific CheatSpan::TargetCalls is 0 --- .../execution/entry_point.rs | 15 +- crates/forge/tests/integration/mock_call.rs | 192 ++++++++++++++++++ 2 files changed, 201 insertions(+), 6 deletions(-) diff --git a/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs b/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs index 9b62aeacc6..e207a750e5 100644 --- a/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs +++ b/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs @@ -2,7 +2,7 @@ use std::cell::RefCell; use super::cairo1_execution::execute_entry_point_call_cairo1; use crate::runtime_extensions::call_to_blockifier_runtime_extension::execution::deprecated::cairo0_execution::execute_entry_point_call_cairo0; use crate::runtime_extensions::call_to_blockifier_runtime_extension::CheatnetState; -use crate::state::{CallTrace, CallTraceNode, CheatStatus, EncounteredError}; +use crate::state::{CallTrace, CallTraceNode, CheatSpan, CheatStatus, EncounteredError}; use blockifier::execution::call_info::{CallExecution, Retdata}; use blockifier::{ execution::{ @@ -278,11 +278,14 @@ fn get_mocked_function_cheat_status<'a>( Some(contract_functions) => { let call_data_hash = poseidon_hash_many(call.calldata.0.iter()); let key = (call.entry_point_selector, call_data_hash); - if contract_functions.contains_key(&key) { - contract_functions.get_mut(&key) - } else { - let key_zero = (call.entry_point_selector, Felt::zero()); - contract_functions.get_mut(&key_zero) + let key_zero = (call.entry_point_selector, Felt::zero()); + + match contract_functions.get(&key) { + Some(CheatStatus::Cheated(_, CheatSpan::TargetCalls(0))) => { + contract_functions.get_mut(&key_zero) + } + Some(CheatStatus::Cheated(_, _)) => contract_functions.get_mut(&key), + _ => contract_functions.get_mut(&key_zero), } } } diff --git a/crates/forge/tests/integration/mock_call.rs b/crates/forge/tests/integration/mock_call.rs index 956be65f03..6346f3a089 100644 --- a/crates/forge/tests/integration/mock_call.rs +++ b/crates/forge/tests/integration/mock_call.rs @@ -388,3 +388,195 @@ fn mock_call_when_complex_types() { let result = run_test_case(&test); assert_passed(&result); } + +#[test] +fn mock_calls_when() { + let test = test_case!( + indoc!( + r#" + use result::ResultTrait; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call_when, MockCallData}; + + #[starknet::interface] + trait IMockChecker { + fn get_thing(ref self: TContractState) -> felt252; + } + + #[test] + fn mock_call_when_one_specific() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + let expected_calldata = MockCallData::Values([].span()); + mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 1); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + + #[test] + fn mock_call_when_twice_specific() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + let expected_calldata = MockCallData::Values([].span()); + mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 2); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + + #[test] + fn mock_call_when_one_any() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data, 1); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + + #[test] + fn mock_call_when_twice_any() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data, 2); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + + "# + ), + Contract::from_code_path( + "MockChecker".to_string(), + Path::new("tests/data/contracts/mock_checker.cairo"), + ) + .unwrap() + ); + + let result = run_test_case(&test); + assert_passed(&result); +} + +#[test] +fn mock_calls_when_mixed() { + let test = test_case!( + indoc!( + r#" + use result::ResultTrait; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call_when, MockCallData}; + + #[starknet::interface] + trait IMockChecker { + fn get_thing(ref self: TContractState) -> felt252; + } + + #[test] + fn mock_call_when_one() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + let expected_calldata = MockCallData::Values([].span()); + mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 1); + mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, 422, 1); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421, "Specific calldata"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 422, "Any calldata"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + + #[test] + fn mock_call_when_multi() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + let expected_calldata = MockCallData::Values([].span()); + mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 3); + mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, 422, 2); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421, "1st Specific calldata"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421, "2nd Specific calldata"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 421, "3rd Specific calldata"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 422, "1st Any calldata"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 422, "2nd Any calldata"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + "# + ), + Contract::from_code_path( + "MockChecker".to_string(), + Path::new("tests/data/contracts/mock_checker.cairo"), + ) + .unwrap() + ); + + let result = run_test_case(&test); + assert_passed(&result); +} From 02a7dcd5965af6ff3cf3324dcb499480888d53ab Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Tue, 4 Feb 2025 01:16:57 +0100 Subject: [PATCH 08/18] docs: add `mock_call_when` cheatcode --- CHANGELOG.md | 1 + docs/src/SUMMARY.md | 1 + docs/src/appendix/cheatcodes.md | 3 ++ .../src/appendix/cheatcodes/mock_call_when.md | 43 +++++++++++++++++++ 4 files changed, 48 insertions(+) create mode 100644 docs/src/appendix/cheatcodes/mock_call_when.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 00e0703606..17614110f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Rust is no longer required to use `snforge` if using Scarb >= 2.10.0 on supported platforms - precompiled `snforge_scarb_plugin` plugin binaries are now published to [package registry](https://scarbs.xyz) for new versions. - Added a suggestion for using the `--max-n-steps` flag when the Cairo VM returns the error: `Could not reach the end of the program. RunResources has no remaining steps`. +- `mock_call_when`, `start_mock_call_when`, `stop_mock_call_when` cheatcodes. #### Fixed diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index a047619daf..23744fd2fc 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -93,6 +93,7 @@ * [fee_data_availability_mode](appendix/cheatcodes/fee_data_availability_mode.md) * [account_deployment_data](appendix/cheatcodes/account_deployment_data.md) * [mock_call](appendix/cheatcodes/mock_call.md) + * [mock_call_when](appendix/cheatcodes/mock_call_when.md) * [get_class_hash](appendix/cheatcodes/get_class_hash.md) * [replace_bytecode](appendix/cheatcodes/replace_bytecode.md) * [l1_handler](appendix/cheatcodes/l1_handler.md) diff --git a/docs/src/appendix/cheatcodes.md b/docs/src/appendix/cheatcodes.md index 67c16f6cfe..410df62885 100644 --- a/docs/src/appendix/cheatcodes.md +++ b/docs/src/appendix/cheatcodes.md @@ -4,6 +4,9 @@ - [`mock_call`](cheatcodes/mock_call.md#mock_call) - mocks a number of contract calls to an entry point - [`start_mock_call`](cheatcodes/mock_call.md#start_mock_call) - mocks contract call to an entry point - [`stop_mock_call`](cheatcodes/mock_call.md#stop_mock_call) - cancels the `mock_call` / `start_mock_call` for an entry point +- [`mock_call_when`](cheatcodes/mock_call_when.md#mock_call_when) - mocks a number of contract calls to an entry point for a given call data +- [`start_mock_call_when`](cheatcodes/mock_call_when.md#start_mock_call_when) - mocks contract call to an entry point for a given call data +- [`stop_mock_call_when`](cheatcodes/mock_call_when.md#stop_mock_call_when) - cancels the `mock_call_when` / `start_mock_call_when` for an entry point - [`get_class_hash`](cheatcodes/get_class_hash.md) - retrieves a class hash of a contract - [`replace_bytecode`](cheatcodes/replace_bytecode.md) - replace the class hash of a contract - [`l1_handler`](cheatcodes/l1_handler.md) - executes a `#[l1_handler]` function to mock a message arriving from Ethereum diff --git a/docs/src/appendix/cheatcodes/mock_call_when.md b/docs/src/appendix/cheatcodes/mock_call_when.md new file mode 100644 index 0000000000..e718a77446 --- /dev/null +++ b/docs/src/appendix/cheatcodes/mock_call_when.md @@ -0,0 +1,43 @@ +# `mock_call_when` + +Cheatcodes mocking contract entry point calls: + +## `MockCallData` + +```rust +pub enum MockCallData { + Any, + Values: Span, +} +``` + +`MockCallData` is an enum used to specify for which call data the contract entry point will be mocked. +- `Any` mock the contract entry point for any call data. +- `Values` mock the contract entry point only for this call data. + +## `mock_call_when` +> `fn mock_call_when, impl TDestruct: Destruct>( +> contract_address: ContractAddress, function_selector: felt252, call_data: MockCallData, ret_data: T, n_times: u32 +> )` + +Mocks contract call to a `function_selector` of a contract at the given address, with the given call data, for `n_times` first calls that are made +to the contract. +A call to function `function_selector` will return data provided in `ret_data` argument. +An address with no contract can be mocked as well. +An entrypoint that is not present on the deployed contract is also possible to mock. +Note that the function is not meant for mocking internal calls - it works only for contract entry points. + +## `start_mock_call_when` +> `fn start_mock_call, impl TDestruct: Destruct>( +> contract_address: ContractAddress, function_selector: felt252, call_data: MockCallData, ret_data: T +> )` + +Mocks contract call to a `function_selector` of a contract at the given address, with the given call data, indefinitely. +See `mock_call_when` for comprehensive definition of how it can be used. + + +### `stop_mock_call_when` + +> `fn stop_mock_call_when(contract_address: ContractAddress, function_selector: felt252, call_data: MockCallData)` + +Cancels the `mock_call_when` / `start_mock_call_when` for the function `function_selector` of a contract at the given addressn with the given call data From 349870106f54b13020077cbad97ca341d2f3b7e6 Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Wed, 12 Feb 2025 00:06:09 +0100 Subject: [PATCH 09/18] use `calldata` instead of `call_data` --- .../execution/entry_point.rs | 4 ++-- .../cheatcodes/mock_call.rs | 16 ++++++++-------- docs/src/appendix/cheatcodes.md | 4 ++-- docs/src/appendix/cheatcodes/mock_call_when.md | 18 +++++++++--------- 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs b/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs index e207a750e5..fd33c1b56a 100644 --- a/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs +++ b/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs @@ -276,8 +276,8 @@ fn get_mocked_function_cheat_status<'a>( { None => None, Some(contract_functions) => { - let call_data_hash = poseidon_hash_many(call.calldata.0.iter()); - let key = (call.entry_point_selector, call_data_hash); + let calldata_hash = poseidon_hash_many(call.calldata.0.iter()); + let key = (call.entry_point_selector, calldata_hash); let key_zero = (call.entry_point_selector, Felt::zero()); match contract_functions.get(&key) { diff --git a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/mock_call.rs b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/mock_call.rs index 81a4e1c92e..9e4216f65e 100644 --- a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/mock_call.rs +++ b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/mock_call.rs @@ -11,16 +11,16 @@ impl CheatnetState { &mut self, contract_address: ContractAddress, function_selector: EntryPointSelector, - call_data: Option>, + calldata: Option>, ret_data: &[Felt], span: CheatSpan, ) { let contract_mocked_functions = self.mocked_functions.entry(contract_address).or_default(); - let call_data_hash = match call_data { + let calldata_hash = match calldata { Some(data) => poseidon_hash_many(data.iter()), None => Felt::zero(), }; - let key = (function_selector, call_data_hash); + let key = (function_selector, calldata_hash); contract_mocked_functions.insert(key, CheatStatus::Cheated(ret_data.to_vec(), span)); } @@ -28,13 +28,13 @@ impl CheatnetState { &mut self, contract_address: ContractAddress, function_selector: EntryPointSelector, - call_data: Option>, + calldata: Option>, ret_data: &[Felt], ) { self.mock_call( contract_address, function_selector, - call_data, + calldata, ret_data, CheatSpan::Indefinite, ); @@ -44,15 +44,15 @@ impl CheatnetState { &mut self, contract_address: ContractAddress, function_selector: EntryPointSelector, - call_data: Option>, + calldata: Option>, ) { if let Entry::Occupied(mut e) = self.mocked_functions.entry(contract_address) { let contract_mocked_functions = e.get_mut(); - let call_data_hash = match call_data { + let calldata_hash = match calldata { Some(data) => poseidon_hash_many(data.iter()), None => Felt::zero(), }; - contract_mocked_functions.remove(&(function_selector, call_data_hash)); + contract_mocked_functions.remove(&(function_selector, calldata_hash)); } } } diff --git a/docs/src/appendix/cheatcodes.md b/docs/src/appendix/cheatcodes.md index 410df62885..2f5544d000 100644 --- a/docs/src/appendix/cheatcodes.md +++ b/docs/src/appendix/cheatcodes.md @@ -4,8 +4,8 @@ - [`mock_call`](cheatcodes/mock_call.md#mock_call) - mocks a number of contract calls to an entry point - [`start_mock_call`](cheatcodes/mock_call.md#start_mock_call) - mocks contract call to an entry point - [`stop_mock_call`](cheatcodes/mock_call.md#stop_mock_call) - cancels the `mock_call` / `start_mock_call` for an entry point -- [`mock_call_when`](cheatcodes/mock_call_when.md#mock_call_when) - mocks a number of contract calls to an entry point for a given call data -- [`start_mock_call_when`](cheatcodes/mock_call_when.md#start_mock_call_when) - mocks contract call to an entry point for a given call data +- [`mock_call_when`](cheatcodes/mock_call_when.md#mock_call_when) - mocks a number of contract calls to an entry point for a given calldata +- [`start_mock_call_when`](cheatcodes/mock_call_when.md#start_mock_call_when) - mocks contract call to an entry point for a given calldata - [`stop_mock_call_when`](cheatcodes/mock_call_when.md#stop_mock_call_when) - cancels the `mock_call_when` / `start_mock_call_when` for an entry point - [`get_class_hash`](cheatcodes/get_class_hash.md) - retrieves a class hash of a contract - [`replace_bytecode`](cheatcodes/replace_bytecode.md) - replace the class hash of a contract diff --git a/docs/src/appendix/cheatcodes/mock_call_when.md b/docs/src/appendix/cheatcodes/mock_call_when.md index e718a77446..0b13676b02 100644 --- a/docs/src/appendix/cheatcodes/mock_call_when.md +++ b/docs/src/appendix/cheatcodes/mock_call_when.md @@ -11,16 +11,16 @@ pub enum MockCallData { } ``` -`MockCallData` is an enum used to specify for which call data the contract entry point will be mocked. -- `Any` mock the contract entry point for any call data. -- `Values` mock the contract entry point only for this call data. +`MockCallData` is an enum used to specify for which calldata the contract entry point will be mocked. +- `Any` mock the contract entry point for any calldata. +- `Values` mock the contract entry point only for this calldata. ## `mock_call_when` > `fn mock_call_when, impl TDestruct: Destruct>( -> contract_address: ContractAddress, function_selector: felt252, call_data: MockCallData, ret_data: T, n_times: u32 +> contract_address: ContractAddress, function_selector: felt252, calldata: MockCallData, ret_data: T, n_times: u32 > )` -Mocks contract call to a `function_selector` of a contract at the given address, with the given call data, for `n_times` first calls that are made +Mocks contract call to a `function_selector` of a contract at the given address, with the given calldata, for `n_times` first calls that are made to the contract. A call to function `function_selector` will return data provided in `ret_data` argument. An address with no contract can be mocked as well. @@ -29,15 +29,15 @@ Note that the function is not meant for mocking internal calls - it works only f ## `start_mock_call_when` > `fn start_mock_call, impl TDestruct: Destruct>( -> contract_address: ContractAddress, function_selector: felt252, call_data: MockCallData, ret_data: T +> contract_address: ContractAddress, function_selector: felt252, calldata: MockCallData, ret_data: T > )` -Mocks contract call to a `function_selector` of a contract at the given address, with the given call data, indefinitely. +Mocks contract call to a `function_selector` of a contract at the given address, with the given calldata, indefinitely. See `mock_call_when` for comprehensive definition of how it can be used. ### `stop_mock_call_when` -> `fn stop_mock_call_when(contract_address: ContractAddress, function_selector: felt252, call_data: MockCallData)` +> `fn stop_mock_call_when(contract_address: ContractAddress, function_selector: felt252, calldata: MockCallData)` -Cancels the `mock_call_when` / `start_mock_call_when` for the function `function_selector` of a contract at the given addressn with the given call data +Cancels the `mock_call_when` / `start_mock_call_when` for the function `function_selector` of a contract at the given addressn with the given calldata From f4bc28229bd63b4cf4f8ce71f4c8068e92b4dfbd Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Tue, 18 Feb 2025 03:59:47 +0100 Subject: [PATCH 10/18] cheatnet: add MockCalldata enum - rename cairo MockCallData enum into MockCalldata - use Serde derivation for MockCalldata instead of custom serializer. --- .../cheatcodes/mock_call.rs | 16 +++--- .../forge_runtime_extension/mod.rs | 8 +-- crates/cheatnet/src/state.rs | 6 ++ crates/cheatnet/tests/cheatcodes/mock_call.rs | 55 +++++++++++-------- crates/forge/tests/integration/mock_call.rs | 44 +++++++-------- .../src/appendix/cheatcodes/mock_call_when.md | 12 ++-- snforge_std/src/cheatcodes.cairo | 54 ++++++------------ snforge_std/src/lib.cairo | 2 +- 8 files changed, 97 insertions(+), 100 deletions(-) diff --git a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/mock_call.rs b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/mock_call.rs index 9e4216f65e..52ad0a7582 100644 --- a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/mock_call.rs +++ b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/cheatcodes/mock_call.rs @@ -1,4 +1,4 @@ -use crate::state::{CheatSpan, CheatStatus}; +use crate::state::{CheatSpan, CheatStatus, MockCalldata}; use crate::CheatnetState; use num_traits::Zero; use starknet_api::core::{ContractAddress, EntryPointSelector}; @@ -11,14 +11,14 @@ impl CheatnetState { &mut self, contract_address: ContractAddress, function_selector: EntryPointSelector, - calldata: Option>, + calldata: MockCalldata, ret_data: &[Felt], span: CheatSpan, ) { let contract_mocked_functions = self.mocked_functions.entry(contract_address).or_default(); let calldata_hash = match calldata { - Some(data) => poseidon_hash_many(data.iter()), - None => Felt::zero(), + MockCalldata::Values(data) => poseidon_hash_many(data.iter()), + MockCalldata::Any => Felt::zero(), }; let key = (function_selector, calldata_hash); contract_mocked_functions.insert(key, CheatStatus::Cheated(ret_data.to_vec(), span)); @@ -28,7 +28,7 @@ impl CheatnetState { &mut self, contract_address: ContractAddress, function_selector: EntryPointSelector, - calldata: Option>, + calldata: MockCalldata, ret_data: &[Felt], ) { self.mock_call( @@ -44,13 +44,13 @@ impl CheatnetState { &mut self, contract_address: ContractAddress, function_selector: EntryPointSelector, - calldata: Option>, + calldata: MockCalldata, ) { if let Entry::Occupied(mut e) = self.mocked_functions.entry(contract_address) { let contract_mocked_functions = e.get_mut(); let calldata_hash = match calldata { - Some(data) => poseidon_hash_many(data.iter()), - None => Felt::zero(), + MockCalldata::Values(data) => poseidon_hash_many(data.iter()), + MockCalldata::Any => Felt::zero(), }; contract_mocked_functions.remove(&(function_selector, calldata_hash)); } diff --git a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs index c7b76bbb7e..299109715e 100644 --- a/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs +++ b/crates/cheatnet/src/runtime_extensions/forge_runtime_extension/mod.rs @@ -87,7 +87,7 @@ impl<'a> ExtensionLogic for ForgeExtension<'a> { "mock_call" => { let contract_address = input_reader.read()?; let function_selector = input_reader.read()?; - let call_data = input_reader.read()?; + let calldata = input_reader.read()?; let span = input_reader.read()?; let ret_data: Vec<_> = input_reader.read()?; extended_runtime @@ -97,7 +97,7 @@ impl<'a> ExtensionLogic for ForgeExtension<'a> { .mock_call( contract_address, function_selector, - call_data, + calldata, &ret_data, span, ); @@ -106,12 +106,12 @@ impl<'a> ExtensionLogic for ForgeExtension<'a> { "stop_mock_call" => { let contract_address = input_reader.read()?; let function_selector = input_reader.read()?; - let call_data = input_reader.read()?; + let calldata = input_reader.read()?; extended_runtime .extended_runtime .extension .cheatnet_state - .stop_mock_call(contract_address, function_selector, call_data); + .stop_mock_call(contract_address, function_selector, calldata); Ok(CheatcodeHandlingResult::from_serializable(())) } "replace_bytecode" => { diff --git a/crates/cheatnet/src/state.rs b/crates/cheatnet/src/state.rs index 7dfb5d00f9..c7cb1c750a 100644 --- a/crates/cheatnet/src/state.rs +++ b/crates/cheatnet/src/state.rs @@ -41,6 +41,12 @@ pub enum CheatSpan { TargetCalls(usize), } +#[derive(CairoDeserialize, Clone, Debug, PartialEq, Eq)] +pub enum MockCalldata { + Any, + Values(Vec), +} + #[derive(Debug)] pub struct ExtendedStateReader { pub dict_state_reader: DictStateReader, diff --git a/crates/cheatnet/tests/cheatcodes/mock_call.rs b/crates/cheatnet/tests/cheatcodes/mock_call.rs index 4e57fa0a66..e9cd0b3b81 100644 --- a/crates/cheatnet/tests/cheatcodes/mock_call.rs +++ b/crates/cheatnet/tests/cheatcodes/mock_call.rs @@ -8,7 +8,7 @@ use crate::{ common::{deploy_contract, get_contracts}, }; use cheatnet::runtime_extensions::forge_runtime_extension::cheatcodes::declare::declare; -use cheatnet::state::{CheatSpan, CheatnetState}; +use cheatnet::state::{CheatSpan, CheatnetState, MockCalldata}; use conversions::IntoConv; use starknet::core::utils::get_selector_from_name; use starknet_api::core::ContractAddress; @@ -38,7 +38,7 @@ impl MockCallTrait for TestEnvironment { self.cheatnet_state.mock_call( *contract_address, function_selector.into_(), - None, + MockCalldata::Any, &ret_data, span, ); @@ -46,8 +46,11 @@ impl MockCallTrait for TestEnvironment { fn stop_mock_call(&mut self, contract_address: &ContractAddress, function_name: &str) { let function_selector = get_selector_from_name(function_name).unwrap(); - self.cheatnet_state - .stop_mock_call(*contract_address, function_selector.into_(), None); + self.cheatnet_state.stop_mock_call( + *contract_address, + function_selector.into_(), + MockCalldata::Any, + ); } } @@ -69,7 +72,7 @@ fn mock_call_simple() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing"), - None, + MockCalldata::Any, &ret_data, ); @@ -102,7 +105,7 @@ fn mock_call_stop() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing"), - None, + MockCalldata::Any, &ret_data, ); @@ -116,7 +119,11 @@ fn mock_call_stop() { assert_success(output, &ret_data); - cheatnet_state.stop_mock_call(contract_address, felt_selector_from_name("get_thing"), None); + cheatnet_state.stop_mock_call( + contract_address, + felt_selector_from_name("get_thing"), + MockCalldata::Any, + ); let output = call_contract( &mut cached_state, @@ -143,7 +150,11 @@ fn mock_call_stop_no_start() { let selector = felt_selector_from_name("get_thing"); - cheatnet_state.stop_mock_call(contract_address, felt_selector_from_name("get_thing"), None); + cheatnet_state.stop_mock_call( + contract_address, + felt_selector_from_name("get_thing"), + MockCalldata::Any, + ); let output = call_contract( &mut cached_state, @@ -171,10 +182,10 @@ fn mock_call_double() { let selector = felt_selector_from_name("get_thing"); let ret_data = [Felt::from(123)]; - cheatnet_state.start_mock_call(contract_address, selector, None, &ret_data); + cheatnet_state.start_mock_call(contract_address, selector, MockCalldata::Any, &ret_data); let ret_data = [Felt::from(999)]; - cheatnet_state.start_mock_call(contract_address, selector, None, &ret_data); + cheatnet_state.start_mock_call(contract_address, selector, MockCalldata::Any, &ret_data); let output = call_contract( &mut cached_state, @@ -186,7 +197,7 @@ fn mock_call_double() { assert_success(output, &ret_data); - cheatnet_state.stop_mock_call(contract_address, selector, None); + cheatnet_state.stop_mock_call(contract_address, selector, MockCalldata::Any); let output = call_contract( &mut cached_state, @@ -217,7 +228,7 @@ fn mock_call_double_call() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing"), - None, + MockCalldata::Any, &ret_data, ); @@ -259,7 +270,7 @@ fn mock_call_proxy() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing"), - None, + MockCalldata::Any, &ret_data, ); @@ -308,7 +319,7 @@ fn mock_call_proxy_with_other_syscall() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing"), - None, + MockCalldata::Any, &ret_data, ); @@ -358,7 +369,7 @@ fn mock_call_inner_call_no_effect() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing"), - None, + MockCalldata::Any, &ret_data, ); @@ -414,7 +425,7 @@ fn mock_call_library_call_no_effect() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_constant_thing"), - None, + MockCalldata::Any, &ret_data, ); @@ -448,7 +459,7 @@ fn mock_call_before_deployment() { cheatnet_state.start_mock_call( precalculated_address, felt_selector_from_name("get_thing"), - None, + MockCalldata::Any, &ret_data, ); @@ -491,7 +502,7 @@ fn mock_call_not_implemented() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing_not_implemented"), - None, + MockCalldata::Any, &ret_data, ); @@ -522,7 +533,7 @@ fn mock_call_in_constructor() { cheatnet_state.start_mock_call( balance_contract_address, felt_selector_from_name("get_balance"), - None, + MockCalldata::Any, &ret_data, ); @@ -570,14 +581,14 @@ fn mock_call_two_methods() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing"), - None, + MockCalldata::Any, &ret_data, ); cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_constant_thing"), - None, + MockCalldata::Any, &ret_data, ); @@ -615,7 +626,7 @@ fn mock_call_nonexisting_contract() { cheatnet_state.start_mock_call( contract_address, felt_selector_from_name("get_thing"), - None, + MockCalldata::Any, &ret_data, ); diff --git a/crates/forge/tests/integration/mock_call.rs b/crates/forge/tests/integration/mock_call.rs index 6346f3a089..a57712a277 100644 --- a/crates/forge/tests/integration/mock_call.rs +++ b/crates/forge/tests/integration/mock_call.rs @@ -214,7 +214,7 @@ fn mock_call_when_simple() { indoc!( r#" use result::ResultTrait; - use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call_when, stop_mock_call_when, MockCallData }; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call_when, stop_mock_call_when, MockCalldata }; #[starknet::interface] trait IMockChecker { @@ -232,10 +232,10 @@ fn mock_call_when_simple() { let specific_mock_ret_data = 421; let default_mock_ret_data = 404; - let expected_calldata = MockCallData::Values([].span()); + let expected_calldata = MockCalldata::Values([].span()); start_mock_call_when(contract_address, selector!("get_thing"), expected_calldata, specific_mock_ret_data); - start_mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, default_mock_ret_data); + start_mock_call_when(contract_address, selector!("get_thing"), MockCalldata::Any, default_mock_ret_data); let thing = dispatcher.get_thing(); assert(thing == specific_mock_ret_data, 'Incorrect thing'); @@ -243,7 +243,7 @@ fn mock_call_when_simple() { let thing = dispatcher.get_thing(); assert(thing == default_mock_ret_data, 'Incorrect thing'); - stop_mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any); + stop_mock_call_when(contract_address, selector!("get_thing"), MockCalldata::Any); let thing = dispatcher.get_thing(); assert(thing == 420, 'Incorrect thing'); } @@ -257,10 +257,10 @@ fn mock_call_when_simple() { let specific_mock_ret_data = 421; let default_mock_ret_data = 404; - let expected_calldata = MockCallData::Values([].span()); + let expected_calldata = MockCalldata::Values([].span()); start_mock_call_when(contract_address, selector!("get_thing"), expected_calldata, specific_mock_ret_data); - start_mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, default_mock_ret_data); + start_mock_call_when(contract_address, selector!("get_thing"), MockCalldata::Any, default_mock_ret_data); let dispatcher = IMockCheckerDispatcher { contract_address }; let thing = dispatcher.get_thing(); @@ -270,7 +270,7 @@ fn mock_call_when_simple() { let thing = dispatcher.get_thing(); assert(thing == default_mock_ret_data, 'Incorrect thing'); - stop_mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any); + stop_mock_call_when(contract_address, selector!("get_thing"), MockCalldata::Any); let thing = dispatcher.get_thing(); assert(thing == 420, 'Incorrect thing'); } @@ -295,7 +295,7 @@ fn mock_call_when_complex_types() { use result::ResultTrait; use array::ArrayTrait; use serde::Serde; - use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call_when, stop_mock_call_when, MockCallData }; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, start_mock_call_when, stop_mock_call_when, MockCalldata }; #[starknet::interface] trait IMockChecker { @@ -320,9 +320,9 @@ fn mock_call_when_complex_types() { let default_mock_ret_data = StructThing {item_one: 412, item_two: 421}; let specific_mock_ret_data = StructThing {item_one: 404, item_two: 401}; - let expected_calldata = MockCallData::Values([].span()); + let expected_calldata = MockCalldata::Values([].span()); - start_mock_call_when(contract_address, selector!("get_struct_thing"), MockCallData::Any, default_mock_ret_data); + start_mock_call_when(contract_address, selector!("get_struct_thing"), MockCalldata::Any, default_mock_ret_data); start_mock_call_when(contract_address, selector!("get_struct_thing"), expected_calldata, specific_mock_ret_data); let thing: StructThing = dispatcher.get_struct_thing(); @@ -349,9 +349,9 @@ fn mock_call_when_complex_types() { let default_mock_ret_data = array![ StructThing {item_one: 112, item_two: 121}, StructThing {item_one: 412, item_two: 421} ]; let specific_mock_ret_data = array![ StructThing {item_one: 212, item_two: 221}, StructThing {item_one: 512, item_two: 521} ]; - let expected_calldata = MockCallData::Values([].span()); + let expected_calldata = MockCalldata::Values([].span()); - start_mock_call_when(contract_address, selector!("get_arr_thing"), MockCallData::Any, default_mock_ret_data); + start_mock_call_when(contract_address, selector!("get_arr_thing"), MockCalldata::Any, default_mock_ret_data); start_mock_call_when(contract_address, selector!("get_arr_thing"), expected_calldata, specific_mock_ret_data); let things: Array = dispatcher.get_arr_thing(); @@ -395,7 +395,7 @@ fn mock_calls_when() { indoc!( r#" use result::ResultTrait; - use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call_when, MockCallData}; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call_when, MockCalldata}; #[starknet::interface] trait IMockChecker { @@ -412,7 +412,7 @@ fn mock_calls_when() { let dispatcher = IMockCheckerDispatcher { contract_address }; let mock_ret_data = 421; - let expected_calldata = MockCallData::Values([].span()); + let expected_calldata = MockCalldata::Values([].span()); mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 1); let thing = dispatcher.get_thing(); @@ -432,7 +432,7 @@ fn mock_calls_when() { let dispatcher = IMockCheckerDispatcher { contract_address }; let mock_ret_data = 421; - let expected_calldata = MockCallData::Values([].span()); + let expected_calldata = MockCalldata::Values([].span()); mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 2); let thing = dispatcher.get_thing(); @@ -455,7 +455,7 @@ fn mock_calls_when() { let dispatcher = IMockCheckerDispatcher { contract_address }; let mock_ret_data = 421; - mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data, 1); + mock_call_when(contract_address, selector!("get_thing"), MockCalldata::Any, mock_ret_data, 1); let thing = dispatcher.get_thing(); assert_eq!(thing, 421); @@ -474,7 +474,7 @@ fn mock_calls_when() { let dispatcher = IMockCheckerDispatcher { contract_address }; let mock_ret_data = 421; - mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, mock_ret_data, 2); + mock_call_when(contract_address, selector!("get_thing"), MockCalldata::Any, mock_ret_data, 2); let thing = dispatcher.get_thing(); assert_eq!(thing, 421); @@ -505,7 +505,7 @@ fn mock_calls_when_mixed() { indoc!( r#" use result::ResultTrait; - use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call_when, MockCallData}; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call_when, MockCalldata}; #[starknet::interface] trait IMockChecker { @@ -522,9 +522,9 @@ fn mock_calls_when_mixed() { let dispatcher = IMockCheckerDispatcher { contract_address }; let mock_ret_data = 421; - let expected_calldata = MockCallData::Values([].span()); + let expected_calldata = MockCalldata::Values([].span()); mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 1); - mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, 422, 1); + mock_call_when(contract_address, selector!("get_thing"), MockCalldata::Any, 422, 1); let thing = dispatcher.get_thing(); assert_eq!(thing, 421, "Specific calldata"); @@ -546,9 +546,9 @@ fn mock_calls_when_mixed() { let dispatcher = IMockCheckerDispatcher { contract_address }; let mock_ret_data = 421; - let expected_calldata = MockCallData::Values([].span()); + let expected_calldata = MockCalldata::Values([].span()); mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_ret_data, 3); - mock_call_when(contract_address, selector!("get_thing"), MockCallData::Any, 422, 2); + mock_call_when(contract_address, selector!("get_thing"), MockCalldata::Any, 422, 2); let thing = dispatcher.get_thing(); assert_eq!(thing, 421, "1st Specific calldata"); diff --git a/docs/src/appendix/cheatcodes/mock_call_when.md b/docs/src/appendix/cheatcodes/mock_call_when.md index 0b13676b02..2bd1a5ad54 100644 --- a/docs/src/appendix/cheatcodes/mock_call_when.md +++ b/docs/src/appendix/cheatcodes/mock_call_when.md @@ -2,22 +2,22 @@ Cheatcodes mocking contract entry point calls: -## `MockCallData` +## `MockCalldata` ```rust -pub enum MockCallData { +pub enum MockCalldata { Any, Values: Span, } ``` -`MockCallData` is an enum used to specify for which calldata the contract entry point will be mocked. +`MockCalldata` is an enum used to specify for which calldata the contract entry point will be mocked. - `Any` mock the contract entry point for any calldata. - `Values` mock the contract entry point only for this calldata. ## `mock_call_when` > `fn mock_call_when, impl TDestruct: Destruct>( -> contract_address: ContractAddress, function_selector: felt252, calldata: MockCallData, ret_data: T, n_times: u32 +> contract_address: ContractAddress, function_selector: felt252, calldata: MockCalldata, ret_data: T, n_times: u32 > )` Mocks contract call to a `function_selector` of a contract at the given address, with the given calldata, for `n_times` first calls that are made @@ -29,7 +29,7 @@ Note that the function is not meant for mocking internal calls - it works only f ## `start_mock_call_when` > `fn start_mock_call, impl TDestruct: Destruct>( -> contract_address: ContractAddress, function_selector: felt252, calldata: MockCallData, ret_data: T +> contract_address: ContractAddress, function_selector: felt252, calldata: MockCalldata, ret_data: T > )` Mocks contract call to a `function_selector` of a contract at the given address, with the given calldata, indefinitely. @@ -38,6 +38,6 @@ See `mock_call_when` for comprehensive definition of how it can be used. ### `stop_mock_call_when` -> `fn stop_mock_call_when(contract_address: ContractAddress, function_selector: felt252, calldata: MockCallData)` +> `fn stop_mock_call_when(contract_address: ContractAddress, function_selector: felt252, calldata: MockCalldata)` Cancels the `mock_call_when` / `start_mock_call_when` for the function `function_selector` of a contract at the given addressn with the given calldata diff --git a/snforge_std/src/cheatcodes.cairo b/snforge_std/src/cheatcodes.cairo index d2ed207e41..55d539c970 100644 --- a/snforge_std/src/cheatcodes.cairo +++ b/snforge_std/src/cheatcodes.cairo @@ -20,35 +20,15 @@ pub enum CheatSpan { TargetCalls: usize, } -/// Enum used to specify the call data that should be matched when mocking a contract call. -#[derive(Copy, Drop, PartialEq, Clone, Debug)] -pub enum MockCallData { - /// Matches any call data. +/// Enum used to specify the calldata that should be matched when mocking a contract call. +#[derive(Copy, Drop, PartialEq, Clone, Debug, Serde)] +pub enum MockCalldata { + /// Matches any calldata. Any, - /// Matches the specified serialized call data. + /// Matches the specified serialized calldata. Values: Span, } -impl MockCallDataSerde of Serde { - fn deserialize(ref serialized: Span) -> Option { - let value: Option>> = Serde::deserialize(ref serialized); - - match value { - Option::None => Option::None, - Option::Some(call_data) => match call_data { - Option::None => Option::Some(MockCallData::Any), - Option::Some(data) => Option::Some(MockCallData::Values(data)), - }, - } - } - - fn serialize(self: @MockCallData, ref output: Array) { - match self { - MockCallData::Any => Option::>::None.serialize(ref output), - MockCallData::Values(data) => Option::Some(*data).serialize(ref output), - } - } -} pub fn test_selector() -> felt252 { // Result of selector!("TEST_CONTRACT_SELECTOR") since `selector!` macro requires dependency on @@ -76,7 +56,7 @@ pub fn test_address() -> ContractAddress { pub fn mock_call, impl TDestruct: Destruct>( contract_address: ContractAddress, function_selector: felt252, ret_data: T, n_times: u32 ) { - mock_call_when(contract_address, function_selector, MockCallData::Any, ret_data, n_times) + mock_call_when(contract_address, function_selector, MockCalldata::Any, ret_data, n_times) } /// Mocks contract call to a function of a contract at the given address, indefinitely. @@ -88,7 +68,7 @@ pub fn mock_call, impl TDestruct: Destruct pub fn start_mock_call, impl TDestruct: Destruct>( contract_address: ContractAddress, function_selector: felt252, ret_data: T ) { - start_mock_call_when(contract_address, function_selector, MockCallData::Any, ret_data) + start_mock_call_when(contract_address, function_selector, MockCalldata::Any, ret_data) } /// Cancels the `mock_call` / `start_mock_call` for the function with given name and contract @@ -97,7 +77,7 @@ pub fn start_mock_call, impl TDestruct: De /// - `function_selector` - hashed name of the target function (can be obtained with `selector!` /// macro) pub fn stop_mock_call(contract_address: ContractAddress, function_selector: felt252,) { - stop_mock_call_when(contract_address, function_selector, MockCallData::Any) + stop_mock_call_when(contract_address, function_selector, MockCalldata::Any) } /// Mocks contract call to a `function_selector` of a contract at the given address, for `n_times` @@ -110,13 +90,13 @@ pub fn stop_mock_call(contract_address: ContractAddress, function_selector: felt /// - `contract_address` - target contract address /// - `function_selector` - hashed name of the target function (can be obtained with `selector!` /// macro) -/// - `call_data` - matching call data +/// - `calldata` - matching calldata /// - `ret_data` - data to return by the function `function_selector` /// - `n_times` - number of calls to mock the function for pub fn mock_call_when, impl TDestruct: Destruct>( contract_address: ContractAddress, function_selector: felt252, - call_data: MockCallData, + calldata: MockCalldata, ret_data: T, n_times: u32 ) { @@ -124,7 +104,7 @@ pub fn mock_call_when, impl TDestruct: Des let contract_address_felt: felt252 = contract_address.into(); let mut inputs = array![contract_address_felt, function_selector]; - call_data.serialize(ref inputs); + calldata.serialize(ref inputs); CheatSpan::TargetCalls(n_times).serialize(ref inputs); let mut ret_data_arr = ArrayTrait::new(); @@ -141,17 +121,17 @@ pub fn mock_call_when, impl TDestruct: Des /// - `contract_address` - targeted contracts' address /// - `function_selector` - hashed name of the target function (can be obtained with `selector!` /// macro) -/// - `call_data` - matching call data +/// - `calldata` - matching calldata /// - `ret_data` - data to be returned by the function pub fn start_mock_call_when, impl TDestruct: Destruct>( contract_address: ContractAddress, function_selector: felt252, - call_data: MockCallData, + calldata: MockCalldata, ret_data: T ) { let contract_address_felt: felt252 = contract_address.into(); let mut inputs = array![contract_address_felt, function_selector]; - call_data.serialize(ref inputs); + calldata.serialize(ref inputs); CheatSpan::Indefinite.serialize(ref inputs); let mut ret_data_arr = ArrayTrait::new(); @@ -166,14 +146,14 @@ pub fn start_mock_call_when, impl TDestruc /// contract address. /// - `contract_address` - targeted contracts' address /// - `function_selector` - hashed name of the target function (can be obtained with `selector!` -/// - `call_data` - matching call data +/// - `calldata` - matching calldata /// macro) pub fn stop_mock_call_when( - contract_address: ContractAddress, function_selector: felt252, call_data: MockCallData + contract_address: ContractAddress, function_selector: felt252, calldata: MockCalldata ) { let contract_address_felt: felt252 = contract_address.into(); let mut inputs = array![contract_address_felt, function_selector]; - call_data.serialize(ref inputs); + calldata.serialize(ref inputs); execute_cheatcode_and_deserialize::<'stop_mock_call', ()>(inputs.span()); } diff --git a/snforge_std/src/lib.cairo b/snforge_std/src/lib.cairo index ab2532bcab..9b665bd9a6 100644 --- a/snforge_std/src/lib.cairo +++ b/snforge_std/src/lib.cairo @@ -30,7 +30,7 @@ pub use cheatcodes::CheatSpan; pub use cheatcodes::ReplaceBytecodeError; pub use cheatcodes::test_address; pub use cheatcodes::test_selector; -pub use cheatcodes::MockCallData; +pub use cheatcodes::MockCalldata; pub use cheatcodes::mock_call; pub use cheatcodes::start_mock_call; pub use cheatcodes::stop_mock_call; From 90cff00a92a643edb23367abadb3623c46b0860d Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Fri, 2 May 2025 11:32:27 +0200 Subject: [PATCH 11/18] fix: add missing argument for run_test_case --- crates/forge/tests/integration/mock_call.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/forge/tests/integration/mock_call.rs b/crates/forge/tests/integration/mock_call.rs index a2e2ecec22..225501eefe 100644 --- a/crates/forge/tests/integration/mock_call.rs +++ b/crates/forge/tests/integration/mock_call.rs @@ -284,7 +284,7 @@ fn mock_call_when_simple() { .unwrap() ); - let result = run_test_case(&test); + let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); assert_passed(&result); } @@ -386,7 +386,7 @@ fn mock_call_when_complex_types() { .unwrap() ); - let result = run_test_case(&test); + let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); assert_passed(&result); } @@ -496,7 +496,7 @@ fn mock_calls_when() { .unwrap() ); - let result = run_test_case(&test); + let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); assert_passed(&result); } @@ -578,6 +578,6 @@ fn mock_calls_when_mixed() { .unwrap() ); - let result = run_test_case(&test); + let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); assert_passed(&result); } From 2cd43a499a4a0fd7455012a39f09a3d89dddd9ab Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Fri, 2 May 2025 14:27:02 +0200 Subject: [PATCH 12/18] scarb fmt --- snforge_std/src/cheatcodes.cairo | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/snforge_std/src/cheatcodes.cairo b/snforge_std/src/cheatcodes.cairo index 06f60c6858..7d0777a66d 100644 --- a/snforge_std/src/cheatcodes.cairo +++ b/snforge_std/src/cheatcodes.cairo @@ -68,7 +68,7 @@ pub fn mock_call, impl TDestruct: Destruct /// macro) /// - `ret_data` - data to be returned by the function pub fn start_mock_call, impl TDestruct: Destruct>( - contract_address: ContractAddress, function_selector: felt252, ret_data: T + contract_address: ContractAddress, function_selector: felt252, ret_data: T, ) { start_mock_call_when(contract_address, function_selector, MockCalldata::Any, ret_data) } @@ -78,7 +78,7 @@ pub fn start_mock_call, impl TDestruct: De /// - `contract_address` - targeted contracts' address /// - `function_selector` - hashed name of the target function (can be obtained with `selector!` /// macro) -pub fn stop_mock_call(contract_address: ContractAddress, function_selector: felt252,) { +pub fn stop_mock_call(contract_address: ContractAddress, function_selector: felt252) { stop_mock_call_when(contract_address, function_selector, MockCalldata::Any) } @@ -100,7 +100,7 @@ pub fn mock_call_when, impl TDestruct: Des function_selector: felt252, calldata: MockCalldata, ret_data: T, - n_times: u32 + n_times: u32, ) { assert!(n_times > 0, "cannot mock_call 0 times, n_times argument must be greater than 0"); @@ -151,7 +151,7 @@ pub fn start_mock_call_when, impl TDestruc /// - `calldata` - matching calldata /// macro) pub fn stop_mock_call_when( - contract_address: ContractAddress, function_selector: felt252, calldata: MockCalldata + contract_address: ContractAddress, function_selector: felt252, calldata: MockCalldata, ) { let contract_address_felt: felt252 = contract_address.into(); let mut inputs = array![contract_address_felt, function_selector]; From 2c5ec936014c8534a2d74fdc8197bb57e9121752 Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Sat, 17 May 2025 17:49:55 +0200 Subject: [PATCH 13/18] remove workaround for #2927 --- .../execution/entry_point.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs b/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs index 4f59227909..1cfdd19790 100644 --- a/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs +++ b/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs @@ -3,7 +3,7 @@ use crate::runtime_extensions::call_to_blockifier_runtime_extension::execution:: use crate::runtime_extensions::call_to_blockifier_runtime_extension::rpc::{AddressOrClassHash, CallResult}; use crate::runtime_extensions::call_to_blockifier_runtime_extension::CheatnetState; use crate::runtime_extensions::common::{get_relocated_vm_trace, get_syscalls_gas_consumed, sum_syscall_usage}; -use crate::state::{CallTrace, CallTraceNode, CheatSpan, CheatStatus}; +use crate::state::{CallTrace, CallTraceNode, CheatStatus}; use blockifier::execution::call_info::{CallExecution, Retdata}; use blockifier::execution::contract_class::{RunnableCompiledClass, TrackedResource}; use blockifier::execution::syscalls::hint_processor::{SyscallUsageMap, ENTRYPOINT_NOT_FOUND_ERROR, OUT_OF_GAS_ERROR}; @@ -410,9 +410,6 @@ fn get_mocked_function_cheat_status<'a>( let key_zero = (call.entry_point_selector, Felt::zero()); match contract_functions.get(&key) { - Some(CheatStatus::Cheated(_, CheatSpan::TargetCalls(0))) => { - contract_functions.get_mut(&key_zero) - } Some(CheatStatus::Cheated(_, _)) => contract_functions.get_mut(&key), _ => contract_functions.get_mut(&key_zero), } From 3aeb49647d949505acbc21c6201e0518aab5bf9b Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Tue, 15 Jul 2025 08:28:32 +0200 Subject: [PATCH 14/18] add test cases with interleaved start/stop mock_call and mock_call_when --- crates/forge/tests/integration/mock_call.rs | 164 ++++++++++++++++++++ 1 file changed, 164 insertions(+) diff --git a/crates/forge/tests/integration/mock_call.rs b/crates/forge/tests/integration/mock_call.rs index 225501eefe..4ae58707f0 100644 --- a/crates/forge/tests/integration/mock_call.rs +++ b/crates/forge/tests/integration/mock_call.rs @@ -581,3 +581,167 @@ fn mock_calls_when_mixed() { let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); assert_passed(&result); } + +#[test] +fn mock_calls_start_stop_when_mixed() { + let test = test_case!( + indoc!( r#" + use result::ResultTrait; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call_when, MockCalldata, start_mock_call, start_mock_call_when, stop_mock_call, stop_mock_call_when}; + + #[starknet::interface] + trait IMockChecker { + fn get_thing(ref self: TContractState) -> felt252; + } + + #[test] + fn mock_calls_start_stop_when_mixed() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + let mock_when_ret_data = 422; + + let expected_calldata = MockCalldata::Values([].span()); + start_mock_call(contract_address, selector!("get_thing"), mock_ret_data); + start_mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_when_ret_data); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_when_ret_data, "1st Mock call when"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_when_ret_data, "2nd Mock call when"); + stop_mock_call_when(contract_address, selector!("get_thing"), expected_calldata); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_ret_data, "Mock call"); + + stop_mock_call(contract_address, selector!("get_thing")); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + + #[test] + fn mock_calls_start_stop_when_count_mixed() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + let mock_when_ret_data = 422; + + let expected_calldata = MockCalldata::Values([].span()); + start_mock_call(contract_address, selector!("get_thing"), mock_ret_data); + mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_when_ret_data, 2); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_when_ret_data, "1st Mock call when"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_when_ret_data, "2nd Mock call when"); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_ret_data, "Mock call"); + + stop_mock_call(contract_address, selector!("get_thing")); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + "#), + Contract::from_code_path( + "MockChecker".to_string(), + Path::new("tests/data/contracts/mock_checker.cairo"), + ) + .unwrap() + ); + + let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); + assert_passed(&result); +} + +#[test] +fn mock_calls_start_stop_when_interleaved() { + let test = test_case!( + indoc!( r#" + use result::ResultTrait; + use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, MockCalldata, start_mock_call, start_mock_call_when, stop_mock_call, stop_mock_call_when}; + + #[starknet::interface] + trait IMockChecker { + fn get_thing(ref self: TContractState) -> felt252; + } + + #[test] + fn mock_calls_start_when_and_stop() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_when_ret_data = 422; + + let expected_calldata = MockCalldata::Values([].span()); + start_mock_call_when(contract_address, selector!("get_thing"), expected_calldata, mock_when_ret_data); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_when_ret_data, "1st Mock call when"); + + stop_mock_call(contract_address, selector!("get_thing")); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_when_ret_data, "2nd Mock call when"); + stop_mock_call_when(contract_address, selector!("get_thing"), expected_calldata); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + + #[test] + fn mock_calls_start_and_stop_when() { + let calldata = array![420]; + + let contract = declare("MockChecker").unwrap().contract_class(); + let (contract_address, _) = contract.deploy(@calldata).unwrap(); + + let dispatcher = IMockCheckerDispatcher { contract_address }; + + let mock_ret_data = 421; + + let expected_calldata = MockCalldata::Values([].span()); + + start_mock_call(contract_address, selector!("get_thing"), mock_ret_data); + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_ret_data, "Mock call"); + + stop_mock_call_when(contract_address, selector!("get_thing"), expected_calldata); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, mock_ret_data, "Mock call"); + + stop_mock_call(contract_address, selector!("get_thing")); + + let thing = dispatcher.get_thing(); + assert_eq!(thing, 420); + } + "#), + Contract::from_code_path( + "MockChecker".to_string(), + Path::new("tests/data/contracts/mock_checker.cairo"), + ) + .unwrap() + ); + + let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); + assert_passed(&result); +} \ No newline at end of file From f235bd7a89bd3eb2eef78c18bb1bba059bc2bc1d Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Tue, 15 Jul 2025 08:36:57 +0200 Subject: [PATCH 15/18] cargo fmt --- .../execution/entry_point.rs | 2 +- crates/forge/tests/integration/mock_call.rs | 26 +++++++++++-------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs b/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs index a8c01b8052..f30ea650ef 100644 --- a/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs +++ b/crates/cheatnet/src/runtime_extensions/call_to_blockifier_runtime_extension/execution/entry_point.rs @@ -28,10 +28,10 @@ use blockifier::{ state::state_api::State, }; use cairo_vm::vm::runners::cairo_runner::{CairoRunner, ExecutionResources}; -use num_traits::Zero; use cairo_vm::vm::trace::trace_entry::RelocatedTraceEntry; use conversions::FromConv; use conversions::string::TryFromHexStr; +use num_traits::Zero; use shared::vm::VirtualMachineExt; use starknet_api::{ contract_class::EntryPointType, diff --git a/crates/forge/tests/integration/mock_call.rs b/crates/forge/tests/integration/mock_call.rs index 4ae58707f0..34a00a616e 100644 --- a/crates/forge/tests/integration/mock_call.rs +++ b/crates/forge/tests/integration/mock_call.rs @@ -585,7 +585,8 @@ fn mock_calls_when_mixed() { #[test] fn mock_calls_start_stop_when_mixed() { let test = test_case!( - indoc!( r#" + indoc!( + r#" use result::ResultTrait; use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, mock_call_when, MockCalldata, start_mock_call, start_mock_call_when, stop_mock_call, stop_mock_call_when}; @@ -656,14 +657,15 @@ fn mock_calls_start_stop_when_mixed() { let thing = dispatcher.get_thing(); assert_eq!(thing, 420); } - "#), - Contract::from_code_path( + "# + ), + Contract::from_code_path( "MockChecker".to_string(), Path::new("tests/data/contracts/mock_checker.cairo"), ) .unwrap() - ); - + ); + let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); assert_passed(&result); } @@ -671,7 +673,8 @@ fn mock_calls_start_stop_when_mixed() { #[test] fn mock_calls_start_stop_when_interleaved() { let test = test_case!( - indoc!( r#" + indoc!( + r#" use result::ResultTrait; use snforge_std::{ declare, ContractClassTrait, DeclareResultTrait, MockCalldata, start_mock_call, start_mock_call_when, stop_mock_call, stop_mock_call_when}; @@ -734,14 +737,15 @@ fn mock_calls_start_stop_when_interleaved() { let thing = dispatcher.get_thing(); assert_eq!(thing, 420); } - "#), - Contract::from_code_path( + "# + ), + Contract::from_code_path( "MockChecker".to_string(), Path::new("tests/data/contracts/mock_checker.cairo"), ) .unwrap() - ); - + ); + let result = run_test_case(&test, ForgeTrackedResource::CairoSteps); assert_passed(&result); -} \ No newline at end of file +} From 3ae2a2815b24cc86b71a4b21fd76d85063bc5250 Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Wed, 6 Aug 2025 01:02:53 +0200 Subject: [PATCH 16/18] add missing use for cheatcodes --- snforge_std/src/lib.cairo | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/snforge_std/src/lib.cairo b/snforge_std/src/lib.cairo index f819af4ebb..9c9690b829 100644 --- a/snforge_std/src/lib.cairo +++ b/snforge_std/src/lib.cairo @@ -108,8 +108,9 @@ pub use cheatcodes::message_to_l1::{ pub use cheatcodes::storage::store; pub use cheatcodes::storage::{interact_with_state, load, map_entry_address}; pub use cheatcodes::{ - ReplaceBytecodeError, mock_call, replace_bytecode, start_mock_call, stop_mock_call, - test_address, test_selector, + MockCalldata, ReplaceBytecodeError, mock_call, mock_call_when, replace_bytecode, + start_mock_call, start_mock_call_when, stop_mock_call, stop_mock_call_when, test_address, + test_selector, }; pub mod byte_array; From 02604f9d197e2d7ce1d9dc2945e569609445789b Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Tue, 16 Sep 2025 10:33:02 +0200 Subject: [PATCH 17/18] docs: apply PR suggestion --- CHANGELOG.md | 2 +- docs/src/appendix/cheatcodes/mock_call_when.md | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b1104b369..3af7aaf290 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `snforge` now supports [oracles](https://docs.swmansion.com/cairo-oracle/) with `--experimental-oracles` flag. - `--trace-components` flag to allow selecting which components of the trace to do display. Read more [here](https://foundry-rs.github.io/starknet-foundry/snforge-advanced-features/debugging.html#trace-components) +- `mock_call_when`, `start_mock_call_when`, `stop_mock_call_when` cheatcodes. ### Cast @@ -362,7 +363,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Rust is no longer required to use `snforge` if using Scarb >= 2.10.0 on supported platforms - precompiled `snforge_scarb_plugin` plugin binaries are now published to [package registry](https://scarbs.xyz) for new versions. - Added a suggestion for using the `--max-n-steps` flag when the Cairo VM returns the error: `Could not reach the end of the program. RunResources has no remaining steps`. -- `mock_call_when`, `start_mock_call_when`, `stop_mock_call_when` cheatcodes. #### Fixed diff --git a/docs/src/appendix/cheatcodes/mock_call_when.md b/docs/src/appendix/cheatcodes/mock_call_when.md index 2bd1a5ad54..e1330c66e0 100644 --- a/docs/src/appendix/cheatcodes/mock_call_when.md +++ b/docs/src/appendix/cheatcodes/mock_call_when.md @@ -1,6 +1,6 @@ # `mock_call_when` -Cheatcodes mocking contract entry point calls: +Cheatcodes mocking contract entry point calls based on calldata: ## `MockCalldata` @@ -37,7 +37,6 @@ See `mock_call_when` for comprehensive definition of how it can be used. ### `stop_mock_call_when` - > `fn stop_mock_call_when(contract_address: ContractAddress, function_selector: felt252, calldata: MockCalldata)` Cancels the `mock_call_when` / `start_mock_call_when` for the function `function_selector` of a contract at the given addressn with the given calldata From 86811c856449dc6a7b4ce1a62d74eab272420c5e Mon Sep 17 00:00:00 2001 From: Patrice Tisserand Date: Tue, 16 Sep 2025 10:45:48 +0200 Subject: [PATCH 18/18] update CHANGELOG --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5763edef01..e6ead39cfa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Forge +#### Added + +- `mock_call_when`, `start_mock_call_when`, `stop_mock_call_when` cheatcodes. + #### Removed - Possibility to use `#[available_gas]` with unnamed argument. Use named arguments instead, e.g. `#[available_gas(l2_gas: 5)]`. @@ -22,7 +26,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for `meta_tx_v0` syscall with cheatcode compatibility - `snforge` now supports [oracles](https://docs.swmansion.com/cairo-oracle/) with `--experimental-oracles` flag. - `--trace-components` flag to allow selecting which components of the trace to do display. Read more [here](https://foundry-rs.github.io/starknet-foundry/snforge-advanced-features/debugging.html#trace-components) -- `mock_call_when`, `start_mock_call_when`, `stop_mock_call_when` cheatcodes. ### Cast