Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce a new CachedMask for BDN #61

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 34 additions & 39 deletions sign/bdn/bdn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ package bdn
import (
"crypto/cipher"
"errors"
"fmt"
"math/big"

"github.com/drand/kyber"
Expand All @@ -31,23 +32,16 @@ var modulus128 = new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 128), big.NewI
// We also use the entire roster so that the coefficient will vary for the same
// public key used in different roster
func hashPointToR(pubs []kyber.Point) ([]kyber.Scalar, error) {
peers := make([][]byte, len(pubs))
for i, pub := range pubs {
peer, err := pub.MarshalBinary()
if err != nil {
return nil, err
}

peers[i] = peer
}

h, err := blake2s.NewXOF(blake2s.OutputLengthUnknown, nil)
if err != nil {
return nil, err
}

for _, peer := range peers {
_, err := h.Write(peer)
for _, pub := range pubs {
peer, err := pub.MarshalBinary()
if err != nil {
return nil, err
}
_, err = h.Write(peer)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -128,32 +122,35 @@ func (scheme *Scheme) Verify(x kyber.Point, msg, sig []byte) error {

// AggregateSignatures aggregates the signatures using a coefficient for each
// one of them where c = H(pk) and H: keyGroup -> R with R = {1, ..., 2^128}
func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber.Point, error) {
if len(sigs) != mask.CountEnabled() {
return nil, errors.New("length of signatures and public keys must match")
}

coefs, err := hashPointToR(mask.Publics())
func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask Mask) (kyber.Point, error) {
bdnMask, err := newCachedMask(mask, false)
if err != nil {
return nil, err
}

agg := scheme.sigGroup.Point()
for i, buf := range sigs {
peerIndex := mask.IndexOfNthEnabled(i)
if peerIndex < 0 {
// this should never happen as we check the lenths at the beginning
// an error here is probably a bug in the mask
return nil, errors.New("couldn't find the index")
for i := range bdnMask.publics {
if enabled, err := mask.GetBit(i); err != nil {
// this should never happen because of the loop boundary
// an error here is probably a bug in the mask implementation
return nil, fmt.Errorf("couldn't find the index %d: %w", i, err)
} else if !enabled {
continue
}

if len(sigs) == 0 {
return nil, errors.New("length of signatures and public keys must match")
}

buf := sigs[0]
sigs = sigs[1:]

sig := scheme.sigGroup.Point()
err = sig.UnmarshalBinary(buf)
if err != nil {
return nil, err
}

sigC := sig.Clone().Mul(coefs[peerIndex], sig)
sigC := sig.Clone().Mul(bdnMask.coefs[i], sig)
// c+1 because R is in the range [1, 2^128] and not [0, 2^128-1]
sigC = sigC.Add(sigC, sig)
agg = agg.Add(agg, sigC)
Expand All @@ -165,25 +162,23 @@ func (scheme *Scheme) AggregateSignatures(sigs [][]byte, mask *sign.Mask) (kyber
// AggregatePublicKeys aggregates a set of public keys (similarly to
// AggregateSignatures for signatures) using the hash function
// H: keyGroup -> R with R = {1, ..., 2^128}.
func (scheme *Scheme) AggregatePublicKeys(mask *sign.Mask) (kyber.Point, error) {
coefs, err := hashPointToR(mask.Publics())
func (scheme *Scheme) AggregatePublicKeys(mask Mask) (kyber.Point, error) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a breaking change for anyone abstracting over Scheme using an interface (e.g., for testing). If that's going to be an issue, I can look into alternative APIs (e.g., adding additional methods, putting the methods on the CachedMask, etc.

bdnMask, err := newCachedMask(mask, false)
if err != nil {
return nil, err
}

agg := scheme.keyGroup.Point()
for i := 0; i < mask.CountEnabled(); i++ {
peerIndex := mask.IndexOfNthEnabled(i)
if peerIndex < 0 {
for i := range bdnMask.publics {
if enabled, err := mask.GetBit(i); err != nil {
// this should never happen because of the loop boundary
// an error here is probably a bug in the mask implementation
return nil, errors.New("couldn't find the index")
return nil, fmt.Errorf("couldn't find the index %d: %w", i, err)
} else if !enabled {
continue
}

pub := mask.Publics()[peerIndex]
pubC := pub.Clone().Mul(coefs[peerIndex], pub)
pubC = pubC.Add(pubC, pub)
agg = agg.Add(agg, pubC)
agg = agg.Add(agg, bdnMask.getOrComputePubC(i))
}

return agg, nil
Expand Down Expand Up @@ -217,14 +212,14 @@ func Verify(suite pairing.Suite, x kyber.Point, msg, sig []byte) error {
// AggregateSignatures aggregates the signatures using a coefficient for each
// one of them where c = H(pk) and H: G2 -> R with R = {1, ..., 2^128}
// Deprecated: use the new scheme methods instead.
func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask *sign.Mask) (kyber.Point, error) {
func AggregateSignatures(suite pairing.Suite, sigs [][]byte, mask Mask) (kyber.Point, error) {
return NewSchemeOnG1(suite).AggregateSignatures(sigs, mask)
}

// AggregatePublicKeys aggregates a set of public keys (similarly to
// AggregateSignatures for signatures) using the hash function
// H: G2 -> R with R = {1, ..., 2^128}.
// Deprecated: use the new scheme methods instead.
func AggregatePublicKeys(suite pairing.Suite, mask *sign.Mask) (kyber.Point, error) {
func AggregatePublicKeys(suite pairing.Suite, mask Mask) (kyber.Point, error) {
return NewSchemeOnG1(suite).AggregatePublicKeys(mask)
}
100 changes: 100 additions & 0 deletions sign/bdn/bdn_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package bdn

import (
"encoding"
"encoding/hex"
"fmt"
"testing"

Expand Down Expand Up @@ -183,6 +185,104 @@ func Benchmark_BDN_AggregateSigs(b *testing.B) {
}
}

func Benchmark_BDN_BLS12381_AggregateVerify(b *testing.B) {
suite := bls12381.NewBLS12381Suite()
schemeOnG2 := NewSchemeOnG2(suite)

rng := random.New()
pubKeys := make([]kyber.Point, 3000)
privKeys := make([]kyber.Scalar, 3000)
for i := range pubKeys {
privKeys[i], pubKeys[i] = schemeOnG2.NewKeyPair(rng)
}

baseMask, err := sign.NewMask(suite, pubKeys, nil)
require.NoError(b, err)
mask, err := NewCachedMask(baseMask)
require.NoError(b, err)
for i := range pubKeys {
require.NoError(b, mask.SetBit(i, true))
}

msg := []byte("Hello many times Boneh-Lynn-Shacham")
sigs := make([][]byte, len(privKeys))
for i, k := range privKeys {
s, err := schemeOnG2.Sign(k, msg)
require.NoError(b, err)
sigs[i] = s
}

sig, err := schemeOnG2.AggregateSignatures(sigs, mask)
require.NoError(b, err)
sigb, err := sig.MarshalBinary()
require.NoError(b, err)

b.ResetTimer()
for i := 0; i < b.N; i++ {
pk, err := schemeOnG2.AggregatePublicKeys(mask)
require.NoError(b, err)
require.NoError(b, schemeOnG2.Verify(pk, msg, sigb))
}
}

func unmarshalHex[T encoding.BinaryUnmarshaler](t *testing.T, into T, s string) T {
t.Helper()
b, err := hex.DecodeString(s)
require.NoError(t, err)
require.NoError(t, into.UnmarshalBinary(b))
return into
}

// This tests exists to make sure we don't accidentally make breaking changes to signature
// aggregation by using checking against known aggregated signatures and keys.
func TestBDNFixtures(t *testing.T) {
suite := bn256.NewSuite()
schemeOnG1 := NewSchemeOnG1(suite)

public1 := unmarshalHex(t, suite.G2().Point(), "1a30714035c7a161e286e54c191b8c68345bd8239c74925a26290e8e1ae97ed6657958a17dca12c943fadceb11b824402389ff427179e0f10194da3c1b771c6083797d2b5915ea78123cbdb99ea6389d6d6b67dcb512a2b552c373094ee5693524e3ebb4a176f7efa7285c25c80081d8cb598745978f1a63b886c09a316b1493")
private1 := unmarshalHex(t, suite.G2().Scalar(), "49cfe5e9f4532670137184d43c0299f8b635bcacf6b0af7cab262494602d9f38")
public2 := unmarshalHex(t, suite.G2().Point(), "603bc61466ec8762ec6de2ba9a80b9d302d08f580d1685ac45a8e404a6ed549719dc0faf94d896a9983ff23423772720e3de5d800bc200de6f7d7e146162d3183b8880c5c0d8b71ca4b3b40f30c12d8cc0679c81a47c239c6aa7e9cc2edab4a927fe865cd413c1c17e3df8f74108e784cd77dd3e161bdaf30019a55826a32a1f")
private2 := unmarshalHex(t, suite.G2().Scalar(), "493abea4bb35b74c78ad9245f9d37883aeb6ee91f7fb0d8a8e11abf7aa2be581")
public3 := unmarshalHex(t, suite.G2().Point(), "56118769a1f0b6286abacaa32109c1497ab0819c5d21f27317e184b6681c283007aa981cb4760de044946febdd6503ab77a4586bc29c04159e53a6fa5dcb9c0261ccd1cb2e28db5204ca829ac9f6be95f957a626544adc34ba3bc542533b6e2f5cbd0567e343641a61a42b63f26c3625f74b66f6f46d17b3bf1688fae4d455ec")
private3 := unmarshalHex(t, suite.G2().Scalar(), "7fb0ebc317e161502208c3c16a4af890dedc3c7b275e8a04e99c0528aa6a19aa")

sig1Exp, err := hex.DecodeString("0913b76987be19f943be23b636cab9a2484507717326bd8bbdcdbbb6b8d5eb9253cfb3597c3fa550ee4972a398813650825a871f8e0b242ae5ddbce1b7c0e2a8")
require.NoError(t, err)
sig2Exp, err := hex.DecodeString("21195d29b1863bca1559e24375211d1411d8a28a8f4c772870b07f4ccda2fd5e337c1315c210475c683e3aa8b87d3aed3f7255b3087daa30d1e1432dd61d7484")
require.NoError(t, err)
sig3Exp, err := hex.DecodeString("3c1ac80345c1733630dbdc8106925c867544b521c259f9fa9678d477e6e5d3d212b09bc0d95137c3dbc0af2241415156c56e757d5577a609293584d045593195")
require.NoError(t, err)

aggSigExp := unmarshalHex(t, suite.G1().Point(), "520875e6667e0acf489e458c6c2233d09af81afa3b2045e0ec2435cfc582ba2c44af281d688efcf991d20975ce32c9933a09f8c4b38c18ef4b4510d8fa0f09d7")
aggKeyExp := unmarshalHex(t, suite.G2().Point(), "394d47291878a81fefb17708c57cf8078b24c46bf4554b3012732acd15395dbf09f13a65e068de766f5449d1de130f09bf09dc35a67f7f822f2a187230e155891d40db3c51afa5b3e05a039c50d04ff9c788718a2887e34644a55a14a2a2679226a3315c281e03367a4d797db819625e0c662d35e45e0e9e7604c104179ae8a7")

msg := []byte("Hello many times Boneh-Lynn-Shacham")
sig1, err := schemeOnG1.Sign(private1, msg)
require.Nil(t, err)
require.Equal(t, sig1Exp, sig1)

sig2, err := schemeOnG1.Sign(private2, msg)
require.Nil(t, err)
require.Equal(t, sig2Exp, sig2)

sig3, err := schemeOnG1.Sign(private3, msg)
require.Nil(t, err)
require.Equal(t, sig3Exp, sig3)

mask, _ := sign.NewMask(suite, []kyber.Point{public1, public2, public3}, nil)
mask.SetBit(0, true)
mask.SetBit(1, false)
mask.SetBit(2, true)

aggSig, err := schemeOnG1.AggregateSignatures([][]byte{sig1, sig2, sig3}, mask)
require.NoError(t, err)
require.True(t, aggSigExp.Equal(aggSig))

aggKey, err := schemeOnG1.AggregatePublicKeys(mask)
require.NoError(t, err)
require.True(t, aggKeyExp.Equal(aggKey))
}

func TestBDNDeprecatedAPIs(t *testing.T) {
msg := []byte("Hello Boneh-Lynn-Shacham")
suite := bn256.NewSuite()
Expand Down
112 changes: 112 additions & 0 deletions sign/bdn/mask.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package bdn

import (
"fmt"

"github.com/drand/kyber"
"github.com/drand/kyber/sign"
)

type Mask interface {
GetBit(i int) (bool, error)
SetBit(i int, enable bool) error

IndexOfNthEnabled(nth int) int
NthEnabledAtIndex(idx int) int

Publics() []kyber.Point
Participants() []kyber.Point

CountEnabled() int
CountTotal() int

Len() int
Mask() []byte
SetMask(mask []byte) error
Merge(mask []byte) error
}

var _ Mask = (*sign.Mask)(nil)

// We need to rename this, otherwise we have a public field named Mask (when we embed it) which
// conflicts with the function named Mask. It also makes it private, which is nice.
type maskI = Mask

type CachedMask struct {
maskI
coefs []kyber.Scalar
pubKeyC []kyber.Point
// We could call Mask.Publics() instead of keeping these here, but that function copies the
// slice and this field lets us avoid that copy.
publics []kyber.Point
}

// Convert the passed mask (likely a *sign.Mask) into a BDN-specific mask with pre-computed terms.
//
// This cached mask will:
//
// 1. Pre-compute coefficients for signature aggregation. Once the CachedMask has been instantiated,
// distinct sets of signatures can be aggregated without any BLAKE2S hashing.
// 2. Pre-computes the terms for public key aggregation. Once the CachedMask has been instantiated,
// distinct sets of public keys can be aggregated by simply summing the cached terms, ~2 orders
// of magnitude faster than aggregating from scratch.
func NewCachedMask(mask Mask) (*CachedMask, error) {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to bikeshed on the name. Maybe just callit BDNMask?

return newCachedMask(mask, true)
}

func newCachedMask(mask Mask, precomputePubC bool) (*CachedMask, error) {
if m, ok := mask.(*CachedMask); ok {
return m, nil
}

publics := mask.Publics()
coefs, err := hashPointToR(publics)
if err != nil {
return nil, fmt.Errorf("failed to hash public keys: %w", err)
}

cm := &CachedMask{
maskI: mask,
coefs: coefs,
publics: publics,
}

if precomputePubC {
pubKeyC := make([]kyber.Point, len(publics))
for i := range publics {
pubKeyC[i] = cm.getOrComputePubC(i)
}
cm.pubKeyC = pubKeyC
}

return cm, err
}

// Clone copies the BDN mask while keeping the precomputed coefficients, etc.
func (cm *CachedMask) Clone() *CachedMask {
newMask, err := sign.NewMask(nil, cm.publics, nil)
if err != nil {
// Not possible given that we didn't pass our own key.
panic(fmt.Sprintf("failed to create mask: %s", err))
}
if err := newMask.SetMask(cm.Mask()); err != nil {
// Not possible given that we're using the same sized mask.
panic(fmt.Sprintf("failed to create mask: %s", err))
}
return &CachedMask{
maskI: newMask,
coefs: cm.coefs,
pubKeyC: cm.pubKeyC,
publics: cm.publics,
}
}

func (cm *CachedMask) getOrComputePubC(i int) kyber.Point {
if cm.pubKeyC == nil {
// NOTE: don't cache here as we may be sharing this mask between threads.
pub := cm.publics[i]
pubC := pub.Clone().Mul(cm.coefs[i], pub)
return pubC.Add(pubC, pub)
}
return cm.pubKeyC[i]
}
11 changes: 11 additions & 0 deletions sign/mask.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,17 @@ func (m *Mask) SetMask(mask []byte) error {
return nil
}

// GetBit returns true if the given bit is set.
func (m *Mask) GetBit(i int) (bool, error) {
if i >= len(m.publics) || i < 0 {
return false, errors.New("index out of range")
}

byteIndex := i / 8
bitIndex := byte(1) << uint(i&7)
return m.mask[byteIndex]&bitIndex != 0, nil
}

// SetBit turns on or off the bit at the given index.
func (m *Mask) SetBit(i int, enable bool) error {
if i >= len(m.publics) || i < 0 {
Expand Down
Loading