Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/lean_spec/subspecs/ssz/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from lean_spec.types.byte_arrays import Bytes32

BYTES_PER_CHUNK: int = 32
"""The number of bytes in a Merkle tree chunk."""
"""Number of bytes per Merkle chunk."""

BITS_PER_BYTE: int = 8
"""Number of bits per byte."""

ZERO_HASH: Bytes32 = Bytes32(b"\x00" * BYTES_PER_CHUNK)
"""A zero hash, used for padding in the Merkle tree."""
"""A zero hash, used for padding in Merkleization."""
158 changes: 158 additions & 0 deletions src/lean_spec/subspecs/ssz/hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""
SSZ Merkleization entry point (`hash_tree_root`).

This module exposes:
- A `hash_tree_root(value: object) -> Bytes32` singledispatch function.
- A tiny facade `HashTreeRoot.compute(value)` if you prefer a class entrypoint.
"""

from __future__ import annotations

from functools import singledispatch
from math import ceil
from typing import Final, Type

from lean_spec.subspecs.ssz.constants import BYTES_PER_CHUNK
from lean_spec.types.bitfields import Bitlist, Bitvector
from lean_spec.types.boolean import Boolean
from lean_spec.types.byte_arrays import ByteListBase, Bytes32, ByteVectorBase
from lean_spec.types.collections import (
List,
Vector,
)
from lean_spec.types.container import Container
from lean_spec.types.uint import BaseUint
from lean_spec.types.union import Union

from .merkleization import Merkle
from .pack import Packer


@singledispatch
def hash_tree_root(value: object) -> Bytes32:
    """
    Generic entry point for SSZ `hash_tree_root(value)`.

    This base implementation is the fallback for unregistered types; every
    supported type installs a specialization via `@hash_tree_root.register`.

    Raises:
        TypeError: When no specialization is registered for `type(value)`.
    """
    raise TypeError(f"hash_tree_root: unsupported value type {type(value).__name__}")


class HashTreeRoot:
    """Class-based facade over the `hash_tree_root` dispatch function."""

    @staticmethod
    def compute(value: object) -> Bytes32:
        """Forward `value` to the registered singledispatch handlers."""
        return hash_tree_root(value)


@hash_tree_root.register
def _htr_uint(value: BaseUint) -> Bytes32:
    """A uint is a basic scalar: serialize, pack into chunks, merkleize."""
    serialized = value.encode_bytes()
    return Merkle.merkleize(Packer.pack_bytes(serialized))


@hash_tree_root.register
def _htr_boolean(value: Boolean) -> Bytes32:
    """A boolean is a basic scalar: serialize, pack into chunks, merkleize."""
    serialized = value.encode_bytes()
    return Merkle.merkleize(Packer.pack_bytes(serialized))


@hash_tree_root.register
def _htr_bytes(value: bytes) -> Bytes32:
    """Raw `bytes` are handled like ByteVector[N]: pack, then merkleize."""
    chunks = Packer.pack_bytes(value)
    return Merkle.merkleize(chunks)


@hash_tree_root.register
def _htr_bytearray(value: bytearray) -> Bytes32:
    """A bytearray is frozen to immutable bytes and handled like raw bytes."""
    frozen = bytes(value)
    return Merkle.merkleize(Packer.pack_bytes(frozen))


@hash_tree_root.register
def _htr_memoryview(value: memoryview) -> Bytes32:
    """A memoryview is materialized to bytes, then packed and merkleized."""
    return Merkle.merkleize(Packer.pack_bytes(value.tobytes()))


@hash_tree_root.register
def _htr_bytevector(value: ByteVectorBase) -> Bytes32:
    """ByteVector[N]: serialized bytes packed into chunks and merkleized."""
    serialized = value.encode_bytes()
    return Merkle.merkleize(Packer.pack_bytes(serialized))


@hash_tree_root.register
def _htr_bytelist(value: ByteListBase) -> Bytes32:
    """ByteList[LIMIT]: merkleize padded to the chunk limit, then mix in the byte length.

    The chunk limit is `ceil(LIMIT / BYTES_PER_CHUNK)` per the SSZ
    `chunk_count` rule for byte lists.
    """
    data = value.encode_bytes()
    # Exact integer ceiling division. `math.ceil(LIMIT / BYTES_PER_CHUNK)` goes
    # through a float and can lose precision for limits >= 2**53; this form is
    # exact for any int and matches the sibling vector/list handlers.
    limit_chunks = (type(value).LIMIT + BYTES_PER_CHUNK - 1) // BYTES_PER_CHUNK
    root = Merkle.merkleize(Packer.pack_bytes(data), limit=limit_chunks)
    return Merkle.mix_in_length(root, len(data))


@hash_tree_root.register
def _htr_bitvector(value: Bitvector) -> Bytes32:
    """Bitvector[N]: bit-pack into chunks and merkleize at the fixed chunk limit."""
    bit_count = type(value).LENGTH
    # chunk_count(Bitvector[N]) == ceil(N / 256): 256 bits fit in one 32-byte chunk.
    chunk_limit = (bit_count + 255) // 256
    bits = tuple(bool(bit) for bit in value)
    return Merkle.merkleize(Packer.pack_bits(bits), limit=chunk_limit)


@hash_tree_root.register
def _htr_bitlist(value: Bitlist) -> Bytes32:
    """Bitlist[LIMIT]: bit-pack, merkleize at the chunk limit, mix in the bit count."""
    # chunk_count(Bitlist[LIMIT]) == ceil(LIMIT / 256): 256 bits per chunk.
    chunk_limit = (type(value).LIMIT + 255) // 256
    bits = tuple(bool(bit) for bit in value)
    root = Merkle.merkleize(Packer.pack_bits(bits), limit=chunk_limit)
    return Merkle.mix_in_length(root, len(value))


@hash_tree_root.register
def _htr_vector(value: Vector) -> Bytes32:
    """Vector[T, N]: pack basic elements; merkleize child roots for composites."""
    vector_cls = type(value)
    elem_t: Type[object] = vector_cls.ELEMENT_TYPE
    length: int = vector_cls.LENGTH

    if not issubclass(elem_t, (BaseUint, Boolean)):
        # COMPOSITE elements: each child contributes one 32-byte root as a leaf,
        # and the tree width is bounded by the vector length.
        return Merkle.merkleize([hash_tree_root(element) for element in value], limit=length)

    # BASIC elements (uint/boolean): concatenate serialized bytes and pack.
    elem_size = elem_t.get_byte_length() if issubclass(elem_t, BaseUint) else 1
    serialized = b"".join(element.encode_bytes() for element in value)
    # Integer ceiling of (total serialized size / chunk size).
    limit_chunks = (length * elem_size + (BYTES_PER_CHUNK - 1)) // BYTES_PER_CHUNK
    return Merkle.merkleize(Packer.pack_bytes(serialized), limit=limit_chunks)


@hash_tree_root.register
def _htr_list(value: List) -> Bytes32:
    """List[T, LIMIT]: merkleize at the limit, then mix in the element count."""
    list_cls = type(value)
    elem_t: Type[object] = list_cls.ELEMENT_TYPE
    limit: int = list_cls.LIMIT

    if issubclass(elem_t, (BaseUint, Boolean)):
        # BASIC elements: pack the concatenated serialized form; the chunk
        # limit is the integer ceiling of (LIMIT * element size / chunk size).
        elem_size = elem_t.get_byte_length() if issubclass(elem_t, BaseUint) else 1
        serialized = b"".join(element.encode_bytes() for element in value)
        limit_chunks = (limit * elem_size + (BYTES_PER_CHUNK - 1)) // BYTES_PER_CHUNK
        root = Merkle.merkleize(Packer.pack_bytes(serialized), limit=limit_chunks)
    else:
        # COMPOSITE elements: one 32-byte root leaf per child, width bounded by LIMIT.
        root = Merkle.merkleize([hash_tree_root(element) for element in value], limit=limit)

    # Lists always mix their actual element count into the root.
    return Merkle.mix_in_length(root, len(value))


@hash_tree_root.register
def _htr_container(value: Container) -> Bytes32:
    """Container: merkleize field roots in declared (Pydantic model) order."""
    field_names = type(value).model_fields.keys()
    leaves = [hash_tree_root(getattr(value, name)) for name in field_names]
    return Merkle.merkleize(leaves)


@hash_tree_root.register
def _htr_union(value: Union) -> Bytes32:
    """Union: mix the selector into the root of the selected value.

    The `None` arm (selector 0) mixes selector 0 into a single zero chunk.
    """
    # Guard the `None` arm first: the original computed `value.selector()`
    # up front and then discarded it on this branch (dead work).
    if value.selected_type() is None:
        # Use the named chunk-size constant instead of a magic `32`.
        return Merkle.mix_in_selector(Bytes32(b"\x00" * BYTES_PER_CHUNK), 0)
    return Merkle.mix_in_selector(hash_tree_root(value.value()), value.selector())
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from lean_spec.types.byte_arrays import Bytes32

from ..constants import ZERO_HASH
from ..gindex import GeneralizedIndex
from ..utils import hash_nodes
from .gindex import GeneralizedIndex

Root = Bytes32
"""The type of a Merkle tree root."""
Expand Down
118 changes: 118 additions & 0 deletions src/lean_spec/subspecs/ssz/merkleization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Merkleization utilities per SSZ."""

from __future__ import annotations

from typing import List, Optional, Sequence

from lean_spec.subspecs.ssz.constants import ZERO_HASH
from lean_spec.subspecs.ssz.utils import get_power_of_two_ceil, hash_nodes
from lean_spec.types.byte_arrays import Bytes32


class Merkle:
    """Static Merkle helpers for SSZ."""

    @staticmethod
    def merkleize(chunks: Sequence[Bytes32], limit: Optional[int] = None) -> Bytes32:
        """Compute the Merkle root of `chunks` per the SSZ padding/limiting rules.

        Behavior
        --------
        - `limit is None`: pad the leaf layer to the next power of two of `len(chunks)`.
        - `limit >= len(chunks)`: pad to the next power of two of `limit`.
        - `limit < len(chunks)`: raise ValueError (input exceeds limit).
        - No chunks and no limit: return ZERO_HASH.
        - No chunks but a limit: return the zero-subtree root for the padded width.

        Raises:
            ValueError: If `limit` is provided and smaller than `len(chunks)`.
        """
        n = len(chunks)
        if n == 0:
            # With a limit, the tree width is fixed by that limit and the root
            # is the root of a fully-zero subtree of that width.
            if limit is not None:
                return Merkle._zero_tree_root(get_power_of_two_ceil(limit))
            return ZERO_HASH

        # Width of the bottom (leaf) layer after padding/limiting.
        if limit is None:
            width = get_power_of_two_ceil(n)
        else:
            if limit < n:
                raise ValueError("merkleize: input exceeds limit")
            width = get_power_of_two_ceil(limit)

        # A single-leaf tree is its own root.
        if width == 1:
            return chunks[0]

        # Leaf layer: the provided chunks followed by zero padding.
        level: List[Bytes32] = list(chunks) + [ZERO_HASH] * (width - n)

        # Collapse bottom-up: hash adjacent pairs until one node remains.
        # Every level has even length because `width` is a power of two >= 2.
        while len(level) > 1:
            level = [hash_nodes(level[i], level[i + 1]) for i in range(0, len(level), 2)]
        return level[0]

    @staticmethod
    def merkleize_progressive(chunks: Sequence[Bytes32], num_leaves: int = 1) -> Bytes32:
        """Progressive Merkleization (per spec).

        Rare in practice; provided for completeness. Splits on `num_leaves`:
        - right: fixed-width merkleization of the first up-to-`num_leaves` chunks
        - left: recursion on the remaining chunks with a quadrupled `num_leaves`
        """
        if len(chunks) == 0:
            return ZERO_HASH

        # Right branch: fixed-width merkleization of the first `num_leaves` chunks.
        right = Merkle.merkleize(chunks[:num_leaves], num_leaves)

        # Left branch: recursively collapse everything beyond `num_leaves`,
        # quadrupling the subtree width at each step.
        left = (
            Merkle.merkleize_progressive(chunks[num_leaves:], num_leaves * 4)
            if len(chunks) > num_leaves
            else ZERO_HASH
        )

        # Combine branches: left subtree first, then the fixed-width right root.
        return hash_nodes(left, right)

    @staticmethod
    def mix_in_length(root: Bytes32, length: int) -> Bytes32:
        """Mix the length (as uint256 little-endian) into a Merkle root.

        Raises:
            ValueError: If `length` is negative.
        """
        if length < 0:
            raise ValueError("length must be non-negative")
        # The "mix" is `hash(root + length_uint256_le)`.
        le = length.to_bytes(32, "little")
        return hash_nodes(root, Bytes32(le))

    @staticmethod
    def mix_in_selector(root: Bytes32, selector: int) -> Bytes32:
        """Mix the union selector (as uint256 little-endian) into a Merkle root.

        Raises:
            ValueError: If `selector` is negative.
        """
        if selector < 0:
            raise ValueError("selector must be non-negative")
        le = selector.to_bytes(32, "little")
        return hash_nodes(root, Bytes32(le))

    @staticmethod
    def _zero_tree_root(width_pow2: int) -> Bytes32:
        """Return the Merkle root of a full zero tree with `width_pow2` leaves.

        Precondition: `width_pow2` is a power of two >= 1 (callers always pass
        `get_power_of_two_ceil(...)`).
        """
        if width_pow2 <= 1:
            return ZERO_HASH
        node = ZERO_HASH
        remaining = width_pow2
        # Each halving of the width adds one level of `hash(node, node)`:
        # the root of a zero tree of width 2w is hash(root_w, root_w).
        while remaining > 1:
            node = hash_nodes(node, node)
            remaining //= 2
        return node
96 changes: 96 additions & 0 deletions src/lean_spec/subspecs/ssz/pack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Packing helpers for SSZ Merkleization.

These helpers convert existing *serialized* data into 32-byte chunks (Bytes32).
They do not serialize objects themselves; they only arrange bytes into chunks
as required by the SSZ Merkleization rules.

Design notes
------------
- We keep these helpers in a dedicated class (`Packer`) to make call sites explicit
and discoverable (e.g., `Packer.pack_bytes(...)`), while remaining purely static.
- All functions return `list[Bytes32]`, the canonical chunk form fed into `merkleize`.
"""

from __future__ import annotations

from typing import Iterable, List, Sequence

from lean_spec.subspecs.ssz.constants import BITS_PER_BYTE, BYTES_PER_CHUNK
from lean_spec.types.byte_arrays import Bytes32


class Packer:
    """Collection of static helpers to pack byte data into 32-byte chunks."""

    @staticmethod
    def _right_pad_to_chunk(b: bytes) -> bytes:
        """Zero-pad `b` on the right so its length is a multiple of BYTES_PER_CHUNK.

        SSZ Merkleization arranges serialized basic values into 32-byte
        "chunks"; unaligned input is extended with trailing zero bytes.
        """
        remainder = len(b) % BYTES_PER_CHUNK
        if remainder == 0:
            # Already chunk-aligned; nothing to add.
            return b
        # Append the minimal number of zero bytes to reach the next multiple of 32.
        return b + b"\x00" * (BYTES_PER_CHUNK - remainder)

    @staticmethod
    def _partition_chunks(b: bytes) -> List[Bytes32]:
        """Split an aligned byte string into Bytes32 chunks.

        Precondition: `len(b)` must be a multiple of BYTES_PER_CHUNK.
        """
        total = len(b)
        if total == 0:
            return []
        if total % BYTES_PER_CHUNK != 0:
            raise ValueError("partition requires a multiple of BYTES_PER_CHUNK")
        # Walk the buffer in 32-byte strides, wrapping each slice as Bytes32.
        return [
            Bytes32(b[offset : offset + BYTES_PER_CHUNK])
            for offset in range(0, total, BYTES_PER_CHUNK)
        ]

    @staticmethod
    def pack_basic_serialized(serialized_basic_values: Iterable[bytes]) -> List[Bytes32]:
        """Pack *serialized* basic values (e.g. uintN/boolean/byte) into chunks.

        Parameters
        ----------
        serialized_basic_values:
            Iterable of bytes objects; each element is already the SSZ-serialized
            form of a basic value.

        Returns:
        -------
        list[Bytes32]
            Concatenated and right-padded chunks ready for Merkleization.
        """
        # Concatenate all serialized values, pad to alignment, then slice.
        concatenated = b"".join(serialized_basic_values)
        padded = Packer._right_pad_to_chunk(concatenated)
        return Packer._partition_chunks(padded)

    @staticmethod
    def pack_bytes(data: bytes) -> List[Bytes32]:
        """Pack raw bytes (e.g. ByteVector/ByteList content) into 32-byte chunks."""
        padded = Packer._right_pad_to_chunk(data)
        return Packer._partition_chunks(padded)

    @staticmethod
    def pack_bits(bools: Sequence[bool]) -> List[Bytes32]:
        """Pack a boolean sequence into a bitfield, then into 32-byte chunks.

        Notes:
        -----
        - This does **not** add the Bitlist length-delimiter bit. Callers implementing
          Bitlist should add it separately or mix the list length at the Merkle level.
        - Bit ordering follows SSZ (little-endian within each byte).
        """
        if not bools:
            return []
        # One byte holds 8 bits; round the byte count up.
        byte_count = (len(bools) + (BITS_PER_BYTE - 1)) // BITS_PER_BYTE
        bitfield = bytearray(byte_count)
        for position, flag in enumerate(bools):
            if flag:
                # Bit `position` lives at bit (position % 8) of byte (position // 8).
                byte_index, bit_index = divmod(position, BITS_PER_BYTE)
                bitfield[byte_index] |= 1 << bit_index
        padded = Packer._right_pad_to_chunk(bytes(bitfield))
        return Packer._partition_chunks(padded)
Loading
Loading