Skip to content

Commit 30b4f76

Browse files
refactor: update benchmark code generator and test wrapper
1 parent 4e61bca commit 30b4f76

File tree

8 files changed

+266
-94
lines changed

8 files changed

+266
-94
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Benchmark code generator classes for creating optimized bytecode patterns."""
2+
3+
from .benchmark_code_generator import (
4+
BenchmarkCodeGenerator,
5+
ExtCallGenerator,
6+
JumpLoopGenerator,
7+
)
8+
9+
__all__ = (
10+
"BenchmarkCodeGenerator",
11+
"ExtCallGenerator",
12+
"JumpLoopGenerator",
13+
)
Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
"""Benchmark code generator classes for creating optimized bytecode patterns."""
22

33
from abc import ABC, abstractmethod
4-
from typing import Optional
4+
from dataclasses import dataclass, field
55

66
from ethereum_test_forks import Fork
7-
from ethereum_test_tools import Alloc, Bytecode, Transaction
8-
from ethereum_test_tools.vm.opcode import Opcodes as Op
7+
from ethereum_test_types import Alloc, Transaction
8+
from ethereum_test_vm import Bytecode
9+
from ethereum_test_vm.opcode import Opcodes as Op
910

1011

12+
@dataclass
1113
class BenchmarkCodeGenerator(ABC):
1214
"""Abstract base class for generating benchmark bytecode."""
1315

14-
def __init__(
15-
self,
16-
fork: Fork,
17-
attack_block: Bytecode,
18-
setup: Optional[Bytecode] = None,
19-
):
20-
"""Initialize with fork, attack block, and optional setup bytecode."""
21-
self.fork = fork
22-
self.setup = setup or Bytecode()
23-
self.attack_block = attack_block
16+
fork: Fork
17+
attack_block: Bytecode
18+
setup: Bytecode = field(default_factory=Bytecode)
19+
20+
@abstractmethod
21+
def deploy_contracts(self, pre: Alloc) -> None:
22+
"""Deploy any contracts needed for the benchmark."""
23+
pass
2424

2525
@abstractmethod
2626
def generate_transaction(self, pre: Alloc, gas_limit: int) -> Transaction:
@@ -29,14 +29,14 @@ def generate_transaction(self, pre: Alloc, gas_limit: int) -> Transaction:
2929

3030
def generate_repeated_code(self, repeated_code: Bytecode, setup: Bytecode) -> Bytecode:
3131
"""Calculate the maximum number of iterations that can fit in the code size limit."""
32+
assert len(repeated_code) > 0, "repeated_code cannot be empty"
3233
max_code_size = self.fork.max_code_size()
3334

34-
overhead = len(Op.JUMPDEST) + len(Op.JUMP(len(setup)))
35+
overhead = len(setup) + len(Op.JUMPDEST) + len(Op.JUMP(len(setup)))
3536
available_space = max_code_size - overhead
36-
max_iterations = available_space // len(repeated_code) if len(repeated_code) > 0 else 0
37+
max_iterations = available_space // len(repeated_code)
3738

3839
code = setup + Op.JUMPDEST + repeated_code * max_iterations + Op.JUMP(len(setup))
39-
4040
self._validate_code_size(code)
4141

4242
return code
@@ -50,47 +50,62 @@ def _validate_code_size(self, code: Bytecode) -> None:
5050
)
5151

5252

53+
@dataclass
5354
class JumpLoopGenerator(BenchmarkCodeGenerator):
5455
"""Generates bytecode that loops execution using JUMP operations."""
5556

56-
def generate_transaction(self, pre: Alloc, gas_limit: int) -> Transaction:
57-
"""Generate transaction with looping bytecode pattern."""
57+
def deploy_contracts(self, pre: Alloc) -> None:
58+
"""Deploy the looping contract."""
5859
# Benchmark Test Structure:
5960
# setup + JUMPDEST + attack + attack + ... + attack + JUMP(setup_length)
60-
6161
code = self.generate_repeated_code(self.attack_block, self.setup)
62+
self._contract_address = pre.deploy_contract(code=code)
63+
64+
def generate_transaction(self, pre: Alloc, gas_limit: int) -> Transaction:
65+
"""Generate transaction that executes the looping contract."""
66+
if not hasattr(self, "_contract_address"):
67+
raise ValueError("deploy_contracts must be called before generate_transaction")
6268

6369
return Transaction(
64-
to=pre.deploy_contract(code=code),
65-
gas_limit=self.fork.transaction_gas_limit_cap() or 30_000_000,
70+
to=self._contract_address,
71+
gas_limit=gas_limit,
6672
sender=pre.fund_eoa(),
6773
)
6874

6975

76+
@dataclass
7077
class ExtCallGenerator(BenchmarkCodeGenerator):
7178
"""Generates bytecode that fills the contract to maximum allowed code size."""
7279

73-
def generate_transaction(self, pre: Alloc, gas_limit: int) -> Transaction:
74-
"""Generate transaction with maximal code size coverage."""
80+
def deploy_contracts(self, pre: Alloc) -> None:
81+
"""Deploy both target and caller contracts."""
7582
# Benchmark Test Structure:
7683
# There are two contracts:
7784
# 1. The target contract that executes certain operation but not loop (e.g. PUSH)
7885
# 2. The loop contract that calls the target contract in a loop
79-
#
80-
# attack = POP(STATICCALL(GAS, target_contract_address, 0, 0, 0, 0))
81-
# setup + JUMPDEST + attack + attack + ... + attack + JUMP(setup_lengt)
82-
# This could optimize the gas consumption and increase the cycle count.
8386

8487
max_stack_height = self.fork.max_stack_height()
8588

86-
target_contract_address = pre.deploy_contract(code=self.attack_block * max_stack_height)
89+
# Deploy target contract that contains the actual attack block
90+
self._target_contract_address = pre.deploy_contract(
91+
code=self.attack_block * max_stack_height
92+
)
8793

88-
code_sequence = Op.POP(Op.STATICCALL(Op.GAS, target_contract_address, 0, 0, 0, 0))
94+
# Create caller contract that repeatedly calls the target contract
95+
# attack = POP(STATICCALL(GAS, target_contract_address, 0, 0, 0, 0))
96+
# setup + JUMPDEST + attack + attack + ... + attack + JUMP(setup_length)
97+
code_sequence = Op.POP(Op.STATICCALL(Op.GAS, self._target_contract_address, 0, 0, 0, 0))
98+
99+
caller_code = self.generate_repeated_code(code_sequence, Bytecode())
100+
self._contract_address = pre.deploy_contract(code=caller_code)
89101

90-
code = self.generate_repeated_code(code_sequence, Bytecode())
102+
def generate_transaction(self, pre: Alloc, gas_limit: int) -> Transaction:
103+
"""Generate transaction that executes the caller contract."""
104+
if not hasattr(self, "_contract_address"):
105+
raise ValueError("deploy_contracts must be called before generate_transaction")
91106

92107
return Transaction(
93-
to=pre.deploy_contract(code=code),
94-
gas_limit=self.fork.transaction_gas_limit_cap() or 30_000_000,
108+
to=self._contract_address,
109+
gas_limit=gas_limit,
95110
sender=pre.fund_eoa(),
96111
)

src/ethereum_test_specs/benchmark.py

Lines changed: 143 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
"""Ethereum benchmark test spec definition and filler."""
22

3-
from typing import Callable, ClassVar, Dict, Generator, List, Optional, Sequence, Type
3+
from contextlib import contextmanager
4+
from contextvars import ContextVar
5+
from enum import Enum
6+
from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, Sequence, Type
47

58
import pytest
6-
from pydantic import Field
9+
from pydantic import ConfigDict, Field
710

811
from ethereum_clis import TransitionTool
912
from ethereum_test_base_types import HexNumber
@@ -29,9 +32,74 @@
2932
from .blockchain import Block, BlockchainTest
3033

3134

35+
class BenchmarkPhase(Enum):
36+
"""Phases of a benchmark test."""
37+
38+
SETUP = "setup"
39+
EXECUTION = "execution"
40+
41+
42+
_current_phase: ContextVar[Optional[BenchmarkPhase]] = ContextVar("benchmark_phase", default=None)
43+
44+
45+
class BenchmarkManager:
46+
"""Context manager for managing benchmark test phases."""
47+
48+
def __init__(self):
49+
"""Initialize the BenchmarkManager with empty transaction and block lists."""
50+
self.setup_transactions: List[Transaction] = []
51+
self.setup_blocks: List[Block] = []
52+
self.execution_transactions: List[Transaction] = []
53+
self.execution_blocks: List[Block] = []
54+
55+
@contextmanager
56+
def setup(self):
57+
"""Context manager for the setup phase of a benchmark test."""
58+
token = _current_phase.set(BenchmarkPhase.SETUP)
59+
try:
60+
yield self
61+
finally:
62+
_current_phase.reset(token)
63+
64+
@contextmanager
65+
def execution(self):
66+
"""Context manager for the execution phase of a benchmark test."""
67+
token = _current_phase.set(BenchmarkPhase.EXECUTION)
68+
try:
69+
yield self
70+
finally:
71+
_current_phase.reset(token)
72+
73+
def add_transaction(self, tx: Transaction):
74+
"""Add a transaction to the current phase."""
75+
current_phase = _current_phase.get()
76+
if current_phase == BenchmarkPhase.SETUP:
77+
self.setup_transactions.append(tx)
78+
elif current_phase == BenchmarkPhase.EXECUTION:
79+
self.execution_transactions.append(tx)
80+
else:
81+
self.setup_transactions.append(tx)
82+
83+
def add_block(self, block: Block):
84+
"""Add a block to the current phase."""
85+
current_phase = _current_phase.get()
86+
if current_phase == BenchmarkPhase.SETUP:
87+
self.setup_blocks.append(block)
88+
elif current_phase == BenchmarkPhase.EXECUTION:
89+
self.execution_blocks.append(block)
90+
else:
91+
self.setup_blocks.append(block)
92+
93+
def get_current_phase(self) -> Optional[BenchmarkPhase]:
94+
"""Get the current benchmark phase."""
95+
return _current_phase.get()
96+
97+
3298
class BenchmarkTest(BaseTest):
3399
"""Test type designed specifically for benchmark test cases."""
34100

101+
model_config = ConfigDict(extra="forbid")
102+
35103
pre: Alloc
36104
post: Alloc
37105
tx: Optional[Transaction] = None
@@ -41,6 +109,9 @@ class BenchmarkTest(BaseTest):
41109
) = None
42110
env: Environment = Field(default_factory=Environment)
43111
expected_benchmark_gas_used: int | None = None
112+
gas_benchmark_value: int
113+
benchmark_manager: Optional[Any] = Field(default=None, exclude=True)
114+
code_generator: Optional[Any] = Field(default=None, exclude=True)
44115

45116
supported_fixture_formats: ClassVar[Sequence[FixtureFormat | LabeledFixtureFormat]] = [
46117
BlockchainFixture,
@@ -86,26 +157,81 @@ def get_genesis_environment(self, fork: Fork) -> Environment:
86157

87158
def split_transaction(self, tx: Transaction, gas_limit_cap: int | None) -> List[Transaction]:
88159
"""Split a transaction that exceeds the gas limit cap into multiple transactions."""
89-
if (gas_limit_cap is None) or (tx.gas_limit <= gas_limit_cap):
160+
if gas_limit_cap is None:
161+
tx.gas_limit = HexNumber(self.gas_benchmark_value)
162+
return [tx]
163+
164+
if gas_limit_cap >= self.gas_benchmark_value:
165+
tx.gas_limit = HexNumber(min(tx.gas_limit, self.gas_benchmark_value))
90166
return [tx]
91167

92-
total_gas = int(self.expected_benchmark_gas_used or self.env.gas_limit)
93-
print(f"total_gas: {total_gas}")
94-
num_splits = total_gas // gas_limit_cap
168+
remaining_gas = self.gas_benchmark_value
169+
num_splits = remaining_gas // gas_limit_cap + int(remaining_gas % gas_limit_cap)
95170

96171
split_transactions = []
97172
for i in range(num_splits):
98173
split_tx = tx.model_copy()
99-
total_gas -= gas_limit_cap
100-
split_tx.gas_limit = HexNumber(total_gas if i == num_splits - 1 else gas_limit_cap)
174+
split_tx.gas_limit = HexNumber(remaining_gas if i == num_splits - 1 else gas_limit_cap)
175+
remaining_gas -= gas_limit_cap
101176
split_tx.nonce = HexNumber(tx.nonce + i)
102177
split_transactions.append(split_tx)
103178

104179
return split_transactions
105180

181+
def generate_blocks_from_code_generator(self, fork: Fork) -> List[Block]:
182+
"""Generate blocks using the code generator."""
183+
if self.code_generator is None:
184+
return []
185+
186+
self.code_generator.deploy_contracts(self.pre)
187+
gas_limit = fork.transaction_gas_limit_cap() or self.gas_benchmark_value
188+
benchmark_tx = self.code_generator.generate_transaction(self.pre, gas_limit)
189+
190+
execution_txs = self.split_transaction(benchmark_tx, gas_limit)
191+
execution_block = Block(txs=execution_txs)
192+
193+
return [execution_block]
194+
106195
def generate_blockchain_test(self, fork: Fork) -> BlockchainTest:
107196
"""Create a BlockchainTest from this BenchmarkTest."""
108-
if self.blocks is not None:
197+
if self.code_generator is not None:
198+
generated_blocks = self.generate_blocks_from_code_generator(fork)
199+
return BlockchainTest.from_test(
200+
base_test=self,
201+
genesis_environment=self.env,
202+
pre=self.pre,
203+
post=self.post,
204+
blocks=generated_blocks,
205+
)
206+
207+
elif self.benchmark_manager is not None:
208+
all_blocks = []
209+
gas_limit = fork.transaction_gas_limit_cap() or self.gas_benchmark_value
210+
211+
if self.benchmark_manager.setup_blocks:
212+
all_blocks.extend(self.benchmark_manager.setup_blocks)
213+
elif self.benchmark_manager.setup_transactions:
214+
setup_txs = []
215+
for tx in self.benchmark_manager.setup_transactions:
216+
setup_txs.extend(self.split_transaction(tx, gas_limit))
217+
all_blocks.append(Block(txs=setup_txs))
218+
219+
if self.benchmark_manager.execution_blocks:
220+
all_blocks.extend(self.benchmark_manager.execution_blocks)
221+
elif self.benchmark_manager.execution_transactions:
222+
execution_txs = []
223+
for tx in self.benchmark_manager.execution_transactions:
224+
execution_txs.extend(self.split_transaction(tx, gas_limit))
225+
all_blocks.append(Block(txs=execution_txs))
226+
227+
return BlockchainTest.from_test(
228+
base_test=self,
229+
genesis_environment=self.env,
230+
pre=self.pre,
231+
post=self.post,
232+
blocks=all_blocks,
233+
)
234+
elif self.blocks is not None:
109235
return BlockchainTest.from_test(
110236
base_test=self,
111237
genesis_environment=self.env,
@@ -114,9 +240,9 @@ def generate_blockchain_test(self, fork: Fork) -> BlockchainTest:
114240
blocks=self.blocks,
115241
)
116242
elif self.tx is not None:
117-
gas_limit_cap = fork.transaction_gas_limit_cap()
243+
gas_limit = fork.transaction_gas_limit_cap() or self.gas_benchmark_value
118244

119-
transactions = self.split_transaction(self.tx, gas_limit_cap)
245+
transactions = self.split_transaction(self.tx, gas_limit)
120246

121247
blocks = [Block(txs=transactions)]
122248

@@ -129,7 +255,7 @@ def generate_blockchain_test(self, fork: Fork) -> BlockchainTest:
129255
)
130256
else:
131257
raise ValueError(
132-
"Cannot create BlockchainTest without transactions, blocks, or code_generator"
258+
"Cannot create BlockchainTest without transactions, blocks, or benchmark_manager"
133259
)
134260

135261
def generate(
@@ -162,5 +288,10 @@ def execute(
162288
raise Exception(f"Unsupported execute format: {execute_format}")
163289

164290

291+
def create_benchmark_manager() -> BenchmarkManager:
292+
"""Create a new BenchmarkManager instance for phase-aware benchmark testing."""
293+
return BenchmarkManager()
294+
295+
165296
BenchmarkTestSpec = Callable[[str], Generator[BenchmarkTest, None, None]]
166297
BenchmarkTestFiller = Type[BenchmarkTest]

src/ethereum_test_tools/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
TestPrivateKey2,
1717
)
1818
from ethereum_test_base_types.reference_spec import ReferenceSpec, ReferenceSpecTypes
19+
from ethereum_test_benchmark import (
20+
BenchmarkCodeGenerator,
21+
ExtCallGenerator,
22+
JumpLoopGenerator,
23+
)
1924
from ethereum_test_exceptions import (
2025
BlockException,
2126
EngineAPIError,
@@ -86,11 +91,6 @@
8691
call_return_code,
8792
)
8893

89-
from .benchmark_code_generator import (
90-
BenchmarkCodeGenerator,
91-
ExtCallGenerator,
92-
JumpLoopGenerator,
93-
)
9494
from .code import (
9595
CalldataCase,
9696
Case,

0 commit comments

Comments
 (0)