33import math
44from abc import ABC , abstractmethod
55from dataclasses import dataclass , field
6- from typing import Callable , ClassVar , Dict , Generator , List , Sequence , Type
6+ from typing import Any , Callable , ClassVar , Dict , Generator , List , Sequence , Type
77
88import pytest
99from pydantic import ConfigDict , Field
1010
1111from ethereum_clis import TransitionTool
12- from ethereum_test_base_types import HexNumber
12+ from ethereum_test_base_types import Address , HexNumber
1313from ethereum_test_exceptions import BlockException , TransactionException
1414from ethereum_test_execution import (
1515 BaseExecute ,
@@ -40,32 +40,52 @@ class BenchmarkCodeGenerator(ABC):
4040
4141 attack_block : Bytecode
4242 setup : Bytecode = field (default_factory = Bytecode )
43+ cleanup : Bytecode = field (default_factory = Bytecode )
44+ tx_kwargs : Dict [str , Any ] = field (default_factory = dict )
45+ _contract_address : Address | None = None
4346
4447 @abstractmethod
45- def deploy_contracts (self , pre : Alloc , fork : Fork ) -> None :
48+ def deploy_contracts (self , * , pre : Alloc , fork : Fork ) -> Address :
4649 """Deploy any contracts needed for the benchmark."""
4750 ...
4851
49- @abstractmethod
50- def generate_transaction (self , pre : Alloc , gas_limit : int ) -> Transaction :
51- """Generate a transaction with the specified gas limit."""
52- ...
52+ def generate_transaction (self , * , pre : Alloc , gas_benchmark_value : int ) -> Transaction :
53+ """Generate transaction that executes the looping contract."""
54+ assert self ._contract_address is not None
55+ if "gas_limit" not in self .tx_kwargs :
56+ self .tx_kwargs ["gas_limit" ] = gas_benchmark_value
57+
58+ return Transaction (
59+ to = self ._contract_address ,
60+ sender = pre .fund_eoa (),
61+ ** self .tx_kwargs ,
62+ )
5363
5464 def generate_repeated_code (
55- self , repeated_code : Bytecode , setup : Bytecode , fork : Fork
65+ self ,
66+ * ,
67+ repeated_code : Bytecode ,
68+ setup : Bytecode | None = None ,
69+ cleanup : Bytecode | None = None ,
70+ fork : Fork ,
5671 ) -> Bytecode :
5772 """
5873 Calculate the maximum number of iterations that
5974 can fit in the code size limit.
6075 """
6176 assert len (repeated_code ) > 0 , "repeated_code cannot be empty"
6277 max_code_size = fork .max_code_size ()
63-
64- overhead = len (setup ) + len (Op .JUMPDEST ) + len (Op .JUMP (len (setup )))
78+ if setup is None :
79+ setup = Bytecode ()
80+ if cleanup is None :
81+ cleanup = Bytecode ()
82+ overhead = len (setup ) + len (Op .JUMPDEST ) + len (cleanup ) + len (Op .JUMP (len (setup )))
6583 available_space = max_code_size - overhead
6684 max_iterations = available_space // len (repeated_code )
6785
68- code = setup + Op .JUMPDEST + repeated_code * max_iterations + Op .JUMP (len (setup ))
86+ # TODO: Unify the PUSH0 and PUSH1 usage.
87+ code = setup + Op .JUMPDEST + repeated_code * max_iterations + cleanup
88+ code += Op .JUMP (len (setup )) if len (setup ) > 0 else Op .PUSH0 + Op .JUMP
6989 self ._validate_code_size (code , fork )
7090
7191 return code
@@ -84,9 +104,10 @@ class BenchmarkTest(BaseTest):
84104
85105 model_config = ConfigDict (extra = "forbid" )
86106
87- pre : Alloc
107+ pre : Alloc = Field ( default_factory = Alloc )
88108 post : Alloc = Field (default_factory = Alloc )
89109 tx : Transaction | None = None
110+ setup_blocks : List [Block ] = Field (default_factory = list )
90111 blocks : List [Block ] | None = None
91112 block_exception : (
92113 List [TransactionException | BlockException ] | TransactionException | BlockException | None
@@ -115,6 +136,14 @@ class BenchmarkTest(BaseTest):
115136 "blockchain_test_only" : "Only generate a blockchain test fixture" ,
116137 }
117138
139+ def model_post_init (self , __context : Any , / ) -> None :
140+ """
141+ Model post-init to assert that the custom pre-allocation was
142+ provided and the default was not used.
143+ """
144+ super ().model_post_init (__context )
145+ assert "pre" in self .model_fields_set , "pre allocation was not provided"
146+
118147 @classmethod
119148 def pytest_parameter_name (cls ) -> str :
120149 """
@@ -178,9 +207,11 @@ def generate_blocks_from_code_generator(self, fork: Fork) -> List[Block]:
178207 if self .code_generator is None :
179208 raise Exception ("Code generator is not set" )
180209
181- self .code_generator .deploy_contracts (self .pre , fork )
210+ self .code_generator .deploy_contracts (pre = self .pre , fork = fork )
182211 gas_limit = fork .transaction_gas_limit_cap () or self .gas_benchmark_value
183- benchmark_tx = self .code_generator .generate_transaction (self .pre , gas_limit )
212+ benchmark_tx = self .code_generator .generate_transaction (
213+ pre = self .pre , gas_benchmark_value = gas_limit
214+ )
184215
185216 execution_txs = self .split_transaction (benchmark_tx , gas_limit )
186217 execution_block = Block (txs = execution_txs )
@@ -204,39 +235,34 @@ def generate_blockchain_test(self, fork: Fork) -> BlockchainTest:
204235 f"Exactly one must be set, but got { len (set_props )} : { ', ' .join (set_props )} "
205236 )
206237
238+ blocks : List [Block ] = self .setup_blocks
239+
207240 if self .code_generator is not None :
208241 generated_blocks = self .generate_blocks_from_code_generator (fork )
209- return BlockchainTest .from_test (
210- base_test = self ,
211- genesis_environment = self .env ,
212- pre = self .pre ,
213- post = self .post ,
214- blocks = generated_blocks ,
215- )
242+ blocks += generated_blocks
243+
216244 elif self .blocks is not None :
217- return BlockchainTest .from_test (
218- base_test = self ,
219- genesis_environment = self .env ,
220- pre = self .pre ,
221- post = self .post ,
222- blocks = self .blocks ,
223- )
245+ blocks += self .blocks
246+
224247 elif self .tx is not None :
225248 gas_limit = fork .transaction_gas_limit_cap () or self .gas_benchmark_value
226249
227250 transactions = self .split_transaction (self .tx , gas_limit )
228251
229- blocks = [ Block (txs = transactions )]
252+ blocks . append ( Block (txs = transactions ))
230253
231- return BlockchainTest .from_test (
232- base_test = self ,
233- pre = self .pre ,
234- post = self .post ,
235- blocks = blocks ,
236- genesis_environment = self .env ,
237- )
238254 else :
239- raise ValueError ("Cannot create BlockchainTest without transactions or blocks" )
255+ raise ValueError (
256+ "Cannot create BlockchainTest without a code generator, transactions, or blocks"
257+ )
258+
259+ return BlockchainTest .from_test (
260+ base_test = self ,
261+ genesis_environment = self .env ,
262+ pre = self .pre ,
263+ post = self .post ,
264+ blocks = blocks ,
265+ )
240266
241267 def generate (
242268 self ,
0 commit comments