Skip to content

Commit

Permalink
Merge pull request #1 from HerodotusDev/mpt
Browse files Browse the repository at this point in the history
mpt fix
  • Loading branch information
petscheit authored Apr 25, 2024
2 parents b737480 + 8842cb5 commit 629825c
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 40 deletions.
101 changes: 68 additions & 33 deletions lib/mpt.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ from lib.rlp_little import (
assert_subset_in_key,
extract_nibble_from_key,
)
from lib.utils import felt_divmod, felt_divmod_8, word_reverse_endian_64, get_felt_bitlength
from lib.utils import felt_divmod, felt_divmod_8, word_reverse_endian_64, get_felt_bitlength_128

// Verify a Merkle Patricia Tree proof.
// params:
Expand All @@ -37,11 +37,12 @@ func verify_mpt_proof{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr:
pow2_array: felt*,
) -> (value: felt*, value_len: felt) {
alloc_locals;
%{ print(f"\n\nNode index {ids.node_index+1}/{ids.mpt_proof_len}") %}
%{ print(f"\n\nNode index {ids.node_index+1}/{ids.mpt_proof_len} \n \t {ids.n_nibbles_already_checked=}") %}
if (node_index == mpt_proof_len - 1) {
// Last node : item of interest is the value.
// Check that the hash of the last node is the expected one.
// Check that the final accumulated key is the expected one.
// Check the number of bytes in the key is equal to the number of bytes checked in the key.
let (node_hash: Uint256) = keccak(mpt_proof[node_index], mpt_proof_bytes_len[node_index]);
%{ print(f"node_hash : {hex(ids.node_hash.low + 2**128*ids.node_hash.high)}") %}
%{ print(f"hash_to_assert : {hex(ids.hash_to_assert.low + 2**128*ids.hash_to_assert.high)}") %}
Expand All @@ -56,7 +57,37 @@ func verify_mpt_proof{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr:
key_little=key_little,
n_nibbles_already_checked=n_nibbles_already_checked,
);
local key_bits;
with pow2_array {
if (key_little.high != 0) {
let key_bit_high = get_felt_bitlength_128(key_little.high);
assert key_bits = 128 + key_bit_high;
} else {
let key_bit_low = get_felt_bitlength_128(key_little.low);
assert key_bits = key_bit_low;
}
}
local n_bytes_in_key;
let (n_bytes_in_key_tmp, rem) = felt_divmod_8(key_bits);

if (n_bytes_in_key_tmp == 0) {
assert n_bytes_in_key = 1;
} else {
if (rem != 0) {
assert n_bytes_in_key = n_bytes_in_key_tmp + 1;
} else {
assert n_bytes_in_key = n_bytes_in_key_tmp;
}
}

local n_bytes_checked;
let (n_bytes_checked_tmp, rem) = felt_divmod(n_nibbles_checked, 2);
if (rem != 0) {
assert n_bytes_checked = n_bytes_checked_tmp + 1;
} else {
assert n_bytes_checked = n_bytes_checked_tmp;
}
assert n_bytes_in_key = n_bytes_checked;
return (item_of_interest, item_of_interest_len);
} else {
// Not last node : item of interest is the hash of the next node.
Expand Down Expand Up @@ -284,11 +315,11 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
// Ensure first_item_type is either 0 or 1.
assert (first_item_type - 1) * (first_item_type) = 0;

let first_item_prefix = extract_nibble_at_byte_pos(
let first_item_key_prefix = extract_nibble_at_byte_pos(
rlp[0], first_item_start_offset + first_item_type, 0, pow2_array
);
%{
prefix = ids.first_item_prefix
prefix = ids.first_item_key_prefix
if prefix == 0:
print("First item is an extension node, even number of nibbles")
elif prefix == 1:
Expand All @@ -301,10 +332,10 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
raise Exception(f"Unknown prefix {prefix} for MPT node with 2 items")
%}
local odd: felt;
if (first_item_prefix == 0) {
if (first_item_key_prefix == 0) {
assert odd = 0;
} else {
if (first_item_prefix == 2) {
if (first_item_key_prefix == 2) {
assert odd = 0;
} else {
// 1 & 3 case.
Expand All @@ -328,9 +359,13 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
);
%{ print(f"nibbles already checked: {ids.n_nibbles_already_checked}") %}

local range_check_ptr_f;
local bitwise_ptr_f: BitwiseBuiltin*;
local n_nibbles_already_checked_f;
local pow2_array_f: felt*;
if (first_item_type != 0) {
// If the first item is not a single byte, verify subset in key.
assert_subset_in_key(
let (n_nibbles_asserted) = assert_subset_in_key(
key_subset=extracted_key_subset,
key_subset_len=extracted_key_subset_len,
key_subset_nibble_len=n_nibbles_in_first_item,
Expand All @@ -339,36 +374,36 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
cut_nibble=odd,
pow2_array=pow2_array,
);
tempvar range_check_ptr = range_check_ptr;
tempvar bitwise_ptr = bitwise_ptr;
tempvar pow2_array = pow2_array;
assert range_check_ptr_f = range_check_ptr;
assert bitwise_ptr_f = bitwise_ptr;
assert n_nibbles_already_checked_f = n_nibbles_already_checked + n_nibbles_asserted;
assert pow2_array_f = pow2_array;
} else {
// if the first item is a single byte, skip subset verification and assert n_nibbles_already_checked == n_nibbles_in_key
local key_bits;
with pow2_array {
if (key_little.high != 0) {
let key_bit_high = get_felt_bitlength(key_little.high);
assert key_bits = 128 + key_bit_high;
} else {
let key_bit_low = get_felt_bitlength(key_little.low);
assert key_bits = key_bit_low;
}
}
local n_nibbles_in_key: felt; // <=> ceil(key_bits/4)
let (n_nibbles_in_key_tmp, remainder) = felt_divmod(key_bits, 4);
if (remainder != 0) {
assert n_nibbles_in_key = n_nibbles_in_key_tmp + 1;
// if the first item is a single byte

if (odd != 0) {
// If the first item has an odd number of nibbles, since there are two nibbles in one byte, the second nibble needs to be checked
let key_nibble = extract_nibble_from_key(
key_little, n_nibbles_already_checked, pow2_array
);
let (_, first_item_nibble) = felt_divmod(first_item_prefix, 2 ** 4);
assert key_nibble = first_item_nibble;
assert range_check_ptr_f = range_check_ptr;
assert bitwise_ptr_f = bitwise_ptr;
assert n_nibbles_already_checked_f = n_nibbles_already_checked + 1;
assert pow2_array_f = pow2_array;
} else {
assert n_nibbles_in_key = n_nibbles_in_key_tmp;
// If the first item has en even number of nibbles, since there are two nibbles in one byte, there is nothing to check.
assert range_check_ptr_f = range_check_ptr;
assert bitwise_ptr_f = bitwise_ptr;
assert n_nibbles_already_checked_f = n_nibbles_already_checked;
assert pow2_array_f = pow2_array;
}
assert n_nibbles_in_key = n_nibbles_already_checked;
tempvar range_check_ptr = range_check_ptr;
tempvar bitwise_ptr = bitwise_ptr;
tempvar pow2_array = pow2_array;
}
let range_check_ptr = range_check_ptr;
let bitwise_ptr = bitwise_ptr;
let pow2_array = pow2_array;
let range_check_ptr = range_check_ptr_f;
let bitwise_ptr = bitwise_ptr_f;
let pow2_array = pow2_array_f;
let n_nibbles_already_checked = n_nibbles_already_checked_f;

// Extract the hash or value.

Expand Down
33 changes: 26 additions & 7 deletions lib/rlp_little.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ from lib.utils import (
get_0xff_mask,
word_reverse_endian_64,
bitwise_divmod,
get_felt_bitlength_128,
)

// Takes a 64 bit word in little endian, returns the byte at a given position as it would be in big endian.
Expand Down Expand Up @@ -79,10 +80,11 @@ func key_subset_to_uint256(key_subset: felt*, key_subset_len: felt) -> Uint256 {
// params:
// key_subset : array of 64 bit words with little endian bytes, representing a subset of the key
// key_subset_len : length of the subset in number of 64 bit words
// key_subset_bytes_len : length of the subset in number of bytes
// key_subset_bytes_len : length of the subset in number of nibbles
// key subset is of the form [b7 b6 b5 b4 b3 b2 b1 b0, b15 b14 b13 b12 b11 b10 b9 b8, ...]
// key_little : 256 bit key in little endian
// key_little is of the form high = [b63, ..., b32] , low = [b31, ..., b0]
// returns the actual number of nibbles checked from key_subset within the actual key_little
func assert_subset_in_key{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
key_subset: felt*,
key_subset_len: felt,
Expand All @@ -91,7 +93,7 @@ func assert_subset_in_key{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
n_nibbles_already_checked: felt,
cut_nibble: felt,
pow2_array: felt*,
) -> () {
) -> (n_nibbles_checked: felt) {
alloc_locals;
let key_subset_256t = key_subset_to_uint256(key_subset, key_subset_len);
%{ print(f"key_susbet_uncut={hex(ids.key_subset_256t.low + ids.key_subset_256t.high*2**128)}") %}
Expand Down Expand Up @@ -122,8 +124,8 @@ func assert_subset_in_key{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
local key_shifted: Uint256;
local key_shifted_last_nibble: felt;
if (odd_checked_nibbles != 0) {
let (upow) = uint256_pow2(Uint256((n_nibbles_already_checked + 1) * 4, 0));
let (key_shiftedt, rem) = uint256_unsigned_div_rem(key_little, upow);
let (upow) = uint256_pow2(Uint256((n_nibbles_already_checked + 1) * 4, 0)); // p = 2**(n_nib_checked+1)
let (key_shiftedt, rem) = uint256_unsigned_div_rem(key_little, upow); //
let (upow_) = uint256_pow2(Uint256((n_nibbles_already_checked - 1) * 4, 0));
let (byte_u256, _) = uint256_unsigned_div_rem(rem, upow_);
let (_, nibble) = felt_divmod(byte_u256.low, 2 ** 4);
Expand Down Expand Up @@ -156,17 +158,34 @@ func assert_subset_in_key{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
print(f"\t final key high : {hex(ids.key_high)}")
print(f"\t key subset high : {hex(ids.key_subset_256.high)}")
%}

local key_subset_nibbles;
let key_subset_bits = get_felt_bitlength_128{pow2_array=pow2_array}(key_subset_256.high);
let (key_subset_nibbles_tmp, remainder) = felt_divmod(128 + key_subset_bits, 4);
if (remainder != 0) {
assert key_subset_nibbles = key_subset_nibbles_tmp + 1;
} else {
assert key_subset_nibbles = key_subset_nibbles_tmp;
}
assert key_subset_256.low = key_shifted.low;
assert key_subset_256.high = key_high;
assert key_subset_last_nibble = key_shifted_last_nibble;
return ();
return (n_nibbles_checked=key_subset_nibbles + cut_nibble);
} else {
let (_, key_low) = felt_divmod(key_shifted.low, pow2_array[4 * key_subset_nibble_len]);
assert key_subset_256.low = key_low;
assert key_subset_256.high = 0;
assert key_subset_last_nibble = key_shifted_last_nibble;
return ();
local key_subset_nibbles;
let key_subset_bits = get_felt_bitlength_128{pow2_array=pow2_array}(key_subset_256.low);
let (key_subset_nibbles_tmp, remainder) = felt_divmod(key_subset_bits, 4);

if (remainder != 0) {
assert key_subset_nibbles = key_subset_nibbles_tmp + 1;
} else {
assert key_subset_nibbles = key_subset_nibbles_tmp;
}

return (n_nibbles_checked=key_subset_nibbles + cut_nibble);
}
}

Expand Down
34 changes: 34 additions & 0 deletions lib/utils.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,40 @@ func get_felt_bitlength{range_check_ptr, pow2_array: felt*}(x: felt) -> felt {
return bit_length;
}

// Returns the number of bits in x.
// Implicits arguments:
// - pow2_array: felt* - A pointer such that pow2_array[i] = 2^i for i in [0, 128].
// Params:
// - x: felt - Input value.
// Assumptions for the caller:
// - 1 <= x < 2^128
// Returns:
// - bit_length: felt - Number of bits in x.
func get_felt_bitlength_128{range_check_ptr, pow2_array: felt*}(x: felt) -> felt {
alloc_locals;
local bit_length;
%{
x = ids.x
ids.bit_length = x.bit_length()
%}
if (bit_length == 128) {
assert [range_check_ptr] = x - 2 ** 127;
tempvar range_check_ptr = range_check_ptr + 1;
return bit_length;
} else {
// Computes N=2^bit_length and n=2^(bit_length-1)
// x is supposed to verify n = 2^(b-1) <= x < N = 2^bit_length <=> x has bit_length bits
tempvar N = pow2_array[bit_length];
tempvar n = pow2_array[bit_length - 1];
assert [range_check_ptr] = bit_length;
assert [range_check_ptr + 1] = 128 - bit_length;
assert [range_check_ptr + 2] = N - x - 1;
assert [range_check_ptr + 3] = x - n;
tempvar range_check_ptr = range_check_ptr + 4;
return bit_length;
}
}

// Computes x//y and x%y.
// Assumption: y must be a power of 2
// params:
Expand Down

0 comments on commit 629825c

Please sign in to comment.