Skip to content

Commit 6d3b699

Browse files
committed
mssmt: add tree copy functionality for full and compacted trees
This commit introduces a new `Copy` method to both the `FullTree` and `CompactedTree` implementations of the MS-SMT. This method allows copying all key-value pairs from a source tree to a target tree, assuming the target tree is initially empty. The `Copy` method is implemented differently for each tree type: - For `FullTree`, the method recursively traverses the tree, collecting all non-empty leaf nodes along with their keys. It then inserts these leaves into the target tree. - For `CompactedTree`, the method similarly traverses the tree, collecting all non-empty compacted leaf nodes along with their keys. It then inserts these leaves into the target tree. A new test case, `TestTreeCopy`, is added to verify the correctness of the `Copy` method for both tree types, including copying between different tree types (FullTree to CompactedTree and vice versa). The test case generates a set of random leaves, inserts them into a source tree, copies the source tree to a target tree, and then verifies that the target tree contains the same leaves as the source tree.
1 parent 581c881 commit 6d3b699

File tree

4 files changed

+370
-0
lines changed

4 files changed

+370
-0
lines changed

mssmt/compacted_tree.go

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,3 +392,166 @@ func (t *CompactedTree) MerkleProof(ctx context.Context, key [hashSize]byte) (
392392

393393
return NewProof(proof), nil
394394
}
395+
396+
// collectLeaves recursively traverses the compacted tree represented by the
397+
// given node and collects all non-empty leaf nodes along with their keys.
398+
func collectLeaves(ctx context.Context, tx TreeStoreViewTx, node Node) (
399+
map[[hashSize]byte]*LeafNode, error) {
400+
401+
// Base case: If it's a compacted leaf node.
402+
if compactedLeaf, ok := node.(*CompactedLeafNode); ok {
403+
// We only care about non-empty leaves.
404+
if compactedLeaf.LeafNode.IsEmpty() {
405+
return make(map[[hashSize]byte]*LeafNode), nil
406+
}
407+
408+
return map[[hashSize]byte]*LeafNode{
409+
compactedLeaf.Key(): compactedLeaf.LeafNode,
410+
}, nil
411+
}
412+
413+
// Recursive step: If it's a branch node.
414+
if _, ok := node.(*BranchNode); ok {
415+
return collectLeavesRecursive(ctx, tx, node, 0)
416+
}
417+
418+
// Handle unexpected node types (like ComputedNode if store returns
419+
// them). If it's an empty leaf node implicitly (e.g., EmptyTree node),
420+
// return empty.
421+
if IsEqualNode(node, EmptyLeafNode) {
422+
return make(map[[hashSize]byte]*LeafNode), nil
423+
}
424+
425+
// Check against EmptyTree branches requires depth. If we encounter an
426+
// unexpected node that isn't identifiable as empty, return error.
427+
return nil, fmt.Errorf("unexpected node type %T encountered "+
428+
"during leaf collection", node)
429+
}
430+
431+
// collectLeavesRecursive is the helper for collectLeaves that includes depth.
432+
func collectLeavesRecursive(ctx context.Context, tx TreeStoreViewTx, node Node,
433+
depth int) (map[[hashSize]byte]*LeafNode, error) {
434+
435+
// Base case: If it's a compacted leaf node.
436+
if compactedLeaf, ok := node.(*CompactedLeafNode); ok {
437+
if compactedLeaf.LeafNode.IsEmpty() {
438+
return make(map[[hashSize]byte]*LeafNode), nil
439+
}
440+
return map[[hashSize]byte]*LeafNode{
441+
compactedLeaf.Key(): compactedLeaf.LeafNode,
442+
}, nil
443+
}
444+
445+
// Recursive step: If it's a branch node.
446+
if branchNode, ok := node.(*BranchNode); ok {
447+
// Optimization: if the branch is empty, return early.
448+
if depth < MaxTreeLevels &&
449+
IsEqualNode(branchNode, EmptyTree[depth]) {
450+
451+
return make(map[[hashSize]byte]*LeafNode), nil
452+
}
453+
454+
// Handle case where depth might exceed EmptyTree bounds if
455+
// logic error exists
456+
if depth >= MaxTreeLevels {
457+
// This shouldn't happen if called correctly, implies a
458+
// leaf.
459+
return nil, fmt.Errorf("invalid depth %d for branch "+
460+
"node", depth)
461+
}
462+
463+
left, right, err := tx.GetChildren(depth, branchNode.NodeHash())
464+
if err != nil {
465+
// If children not found, it might be an empty branch
466+
// implicitly Check if the error indicates "not found"
467+
// or similar Depending on store impl, this might be how
468+
// empty is signaled For now, treat error as fatal.
469+
return nil, fmt.Errorf("error getting children for "+
470+
"branch %s at depth %d: %w",
471+
branchNode.NodeHash(), depth, err)
472+
}
473+
474+
leftLeaves, err := collectLeavesRecursive(
475+
ctx, tx, left, depth+1,
476+
)
477+
if err != nil {
478+
return nil, err
479+
}
480+
481+
rightLeaves, err := collectLeavesRecursive(
482+
ctx, tx, right, depth+1,
483+
)
484+
if err != nil {
485+
return nil, err
486+
}
487+
488+
// Merge the results.
489+
for k, v := range rightLeaves {
490+
// Check for duplicate keys, although this shouldn't
491+
// happen in a valid SMT.
492+
if _, exists := leftLeaves[k]; exists {
493+
return nil, fmt.Errorf("duplicate key %x found "+
494+
"during leaf collection", k)
495+
}
496+
leftLeaves[k] = v
497+
}
498+
499+
return leftLeaves, nil
500+
}
501+
502+
// Handle unexpected node types or implicit empty nodes. If node is nil
503+
// or explicitly an EmptyLeafNode representation
504+
if node == nil || IsEqualNode(node, EmptyLeafNode) {
505+
return make(map[[hashSize]byte]*LeafNode), nil
506+
}
507+
508+
// Check against EmptyTree branches if possible (requires depth)
509+
if depth < MaxTreeLevels && IsEqualNode(node, EmptyTree[depth]) {
510+
return make(map[[hashSize]byte]*LeafNode), nil
511+
}
512+
513+
return nil, fmt.Errorf("unexpected node type %T encountered "+
514+
"during leaf collection at depth %d", node, depth)
515+
}
516+
517+
// Copy copies all the key-value pairs from the source tree into the target
518+
// tree. The target tree is assumed to be empty.
519+
func (t *CompactedTree) Copy(ctx context.Context, targetTree Tree) error {
520+
var leaves map[[hashSize]byte]*LeafNode
521+
err := t.store.View(ctx, func(tx TreeStoreViewTx) error {
522+
root, err := tx.RootNode()
523+
if err != nil {
524+
return fmt.Errorf("error getting root node: %w", err)
525+
}
526+
527+
// Optimization: If the source tree is empty, there's nothing to
528+
// copy.
529+
if IsEqualNode(root, EmptyTree[0]) {
530+
leaves = make(map[[hashSize]byte]*LeafNode)
531+
return nil
532+
}
533+
534+
// Start recursive collection from the root at depth 0.
535+
leaves, err = collectLeavesRecursive(ctx, tx, root, 0)
536+
if err != nil {
537+
return fmt.Errorf("error collecting leaves: %w", err)
538+
}
539+
540+
return nil
541+
})
542+
if err != nil {
543+
return err
544+
}
545+
546+
// Insert all found leaves into the target tree.
547+
for key, leaf := range leaves {
548+
// Use the target tree's Insert method.
549+
_, err := targetTree.Insert(ctx, key, leaf)
550+
if err != nil {
551+
return fmt.Errorf("error inserting leaf with key %x "+
552+
"into target tree: %w", key, err)
553+
}
554+
}
555+
556+
return nil
557+
}

mssmt/interface.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,8 @@ type Tree interface {
3030
// proof. This is noted by the returned `Proof` containing an empty
3131
// leaf.
3232
MerkleProof(ctx context.Context, key [hashSize]byte) (*Proof, error)
33+
34+
// Copy copies all the key-value pairs from the source tree into the
35+
// target tree. The target tree is assumed to be empty.
36+
Copy(ctx context.Context, targetTree Tree) error
3337
}

mssmt/tree.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,112 @@ func (t *FullTree) MerkleProof(ctx context.Context, key [hashSize]byte) (
333333
return NewProof(proof), nil
334334
}
335335

336+
// findLeaves recursively traverses the tree represented by the given node and
337+
// collects all non-empty leaf nodes along with their reconstructed keys.
338+
func findLeaves(ctx context.Context, tx TreeStoreViewTx, node Node,
339+
keyPrefix [hashSize]byte, depth int) (map[[hashSize]byte]*LeafNode, error) {
340+
341+
// Base case: If it's a leaf node.
342+
if leafNode, ok := node.(*LeafNode); ok {
343+
if leafNode.IsEmpty() {
344+
return make(map[[hashSize]byte]*LeafNode), nil
345+
}
346+
return map[[hashSize]byte]*LeafNode{keyPrefix: leafNode}, nil
347+
}
348+
349+
// Recursive step: If it's a branch node.
350+
if branchNode, ok := node.(*BranchNode); ok {
351+
// Optimization: if the branch is empty, return early.
352+
if IsEqualNode(branchNode, EmptyTree[depth]) {
353+
return make(map[[hashSize]byte]*LeafNode), nil
354+
}
355+
356+
left, right, err := tx.GetChildren(depth, branchNode.NodeHash())
357+
if err != nil {
358+
return nil, fmt.Errorf("error getting children for "+
359+
"branch %s at depth %d: %w",
360+
branchNode.NodeHash(), depth, err)
361+
}
362+
363+
// Recursively find leaves in the left subtree. The key prefix
364+
// remains the same as the 0 bit is implicitly handled by the
365+
// initial keyPrefix state.
366+
leftLeaves, err := findLeaves(
367+
ctx, tx, left, keyPrefix, depth+1,
368+
)
369+
if err != nil {
370+
return nil, err
371+
}
372+
373+
// Recursively find leaves in the right subtree. We need to set
374+
// the bit corresponding to the current depth in the key prefix.
375+
rightKeyPrefix := keyPrefix
376+
byteIndex := depth / 8
377+
bitIndex := depth % 8
378+
rightKeyPrefix[byteIndex] |= (1 << bitIndex)
379+
380+
rightLeaves, err := findLeaves(
381+
ctx, tx, right, rightKeyPrefix, depth+1,
382+
)
383+
if err != nil {
384+
return nil, err
385+
}
386+
387+
// Merge the results.
388+
for k, v := range rightLeaves {
389+
leftLeaves[k] = v
390+
}
391+
return leftLeaves, nil
392+
}
393+
394+
// Handle unexpected node types.
395+
return nil, fmt.Errorf("unexpected node type %T encountered "+
396+
"during leaf collection", node)
397+
}
398+
399+
// Copy copies all the key-value pairs from the source tree into the target
400+
// tree. The target tree is assumed to be empty.
401+
func (t *FullTree) Copy(ctx context.Context, targetTree Tree) error {
402+
var leaves map[[hashSize]byte]*LeafNode
403+
err := t.store.View(ctx, func(tx TreeStoreViewTx) error {
404+
root, err := tx.RootNode()
405+
if err != nil {
406+
return fmt.Errorf("error getting root node: %w", err)
407+
}
408+
409+
// Optimization: If the source tree is empty, there's nothing
410+
// to copy.
411+
if IsEqualNode(root, EmptyTree[0]) {
412+
leaves = make(map[[hashSize]byte]*LeafNode)
413+
return nil
414+
}
415+
416+
leaves, err = findLeaves(ctx, tx, root, [hashSize]byte{}, 0)
417+
if err != nil {
418+
return fmt.Errorf("error finding leaves: %w", err)
419+
}
420+
return nil
421+
})
422+
if err != nil {
423+
return err
424+
}
425+
426+
// Insert all found leaves into the target tree. We assume the target
427+
// tree handles batching or individual inserts efficiently.
428+
for key, leaf := range leaves {
429+
// Use the target tree's Insert method. We ignore the returned
430+
// tree as we are modifying the targetTree in place via its
431+
// store.
432+
_, err := targetTree.Insert(ctx, key, leaf)
433+
if err != nil {
434+
return fmt.Errorf("error inserting leaf with key %x "+
435+
"into target tree: %w", key, err)
436+
}
437+
}
438+
439+
return nil
440+
}
441+
336442
// VerifyMerkleProof determines whether a merkle proof for the leaf found at the
337443
// given key is valid.
338444
func VerifyMerkleProof(key [hashSize]byte, leaf *LeafNode, proof *Proof,

mssmt/tree_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,103 @@ func TestBIPTestVectors(t *testing.T) {
822822
}
823823
}
824824

825+
// TestTreeCopy tests the Copy method for both FullTree and CompactedTree,
826+
// including copying between different tree types.
827+
func TestTreeCopy(t *testing.T) {
828+
t.Parallel()
829+
830+
leaves := randTree(50) // Use a smaller number for faster testing
831+
832+
// Prepare source trees (Full and Compacted)
833+
ctx := context.Background()
834+
sourceFullStore := mssmt.NewDefaultStore()
835+
sourceFullTree := mssmt.NewFullTree(sourceFullStore)
836+
sourceCompactedStore := mssmt.NewDefaultStore()
837+
sourceCompactedTree := mssmt.NewCompactedTree(sourceCompactedStore)
838+
839+
for _, item := range leaves {
840+
_, err := sourceFullTree.Insert(ctx, item.key, item.leaf)
841+
require.NoError(t, err)
842+
_, err = sourceCompactedTree.Insert(ctx, item.key, item.leaf)
843+
require.NoError(t, err)
844+
}
845+
846+
sourceFullRoot, err := sourceFullTree.Root(ctx)
847+
require.NoError(t, err)
848+
sourceCompactedRoot, err := sourceCompactedTree.Root(ctx)
849+
require.NoError(t, err)
850+
require.True(t, mssmt.IsEqualNode(sourceFullRoot, sourceCompactedRoot))
851+
852+
// Define test cases
853+
testCases := []struct {
854+
name string
855+
sourceTree mssmt.Tree
856+
makeTarget func() mssmt.Tree
857+
}{
858+
{
859+
name: "Full -> Full",
860+
sourceTree: sourceFullTree,
861+
makeTarget: func() mssmt.Tree {
862+
return mssmt.NewFullTree(mssmt.NewDefaultStore())
863+
},
864+
},
865+
{
866+
name: "Full -> Compacted",
867+
sourceTree: sourceFullTree,
868+
makeTarget: func() mssmt.Tree {
869+
return mssmt.NewCompactedTree(mssmt.NewDefaultStore())
870+
},
871+
},
872+
{
873+
name: "Compacted -> Full",
874+
sourceTree: sourceCompactedTree,
875+
makeTarget: func() mssmt.Tree {
876+
return mssmt.NewFullTree(mssmt.NewDefaultStore())
877+
},
878+
},
879+
{
880+
name: "Compacted -> Compacted",
881+
sourceTree: sourceCompactedTree,
882+
makeTarget: func() mssmt.Tree {
883+
return mssmt.NewCompactedTree(mssmt.NewDefaultStore())
884+
},
885+
},
886+
}
887+
888+
for _, tc := range testCases {
889+
tc := tc
890+
t.Run(tc.name, func(t *testing.T) {
891+
t.Parallel()
892+
893+
targetTree := tc.makeTarget()
894+
895+
// Perform the copy
896+
err := tc.sourceTree.Copy(ctx, targetTree)
897+
require.NoError(t, err)
898+
899+
// Verify the target tree root
900+
targetRoot, err := targetTree.Root(ctx)
901+
require.NoError(t, err)
902+
require.True(t, mssmt.IsEqualNode(sourceFullRoot, targetRoot),
903+
"Root mismatch after copy")
904+
905+
// Verify individual leaves in the target tree
906+
for _, item := range leaves {
907+
targetLeaf, err := targetTree.Get(ctx, item.key)
908+
require.NoError(t, err)
909+
require.Equal(t, item.leaf, targetLeaf,
910+
"Leaf mismatch for key %x", item.key)
911+
}
912+
913+
// Verify a non-existent key is still empty
914+
emptyLeaf, err := targetTree.Get(ctx, test.RandHash())
915+
require.NoError(t, err)
916+
require.True(t, emptyLeaf.IsEmpty(), "Non-existent key found")
917+
})
918+
}
919+
}
920+
921+
825922
// runBIPTestVector runs the tests in a single BIP test vector file.
826923
func runBIPTestVector(t *testing.T, testVectors *mssmt.TestVectors) {
827924
for _, validCase := range testVectors.ValidTestCases {

0 commit comments

Comments
 (0)