diff --git a/.gitignore b/.gitignore index dfb5825..51ea577 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,7 @@ vendor/ # Go workspace file go.work go.work.sum +.gocache # env file .env diff --git a/hash/poseidon2_goldilocks_plonky2/poseidon2.go b/hash/poseidon2_goldilocks_plonky2/poseidon2.go index 33302a9..080b625 100644 --- a/hash/poseidon2_goldilocks_plonky2/poseidon2.go +++ b/hash/poseidon2_goldilocks_plonky2/poseidon2.go @@ -1,6 +1,7 @@ package poseidon2_plonky2 import ( + "encoding/binary" "fmt" "hash" @@ -76,6 +77,11 @@ func HashNToHashNoPad(input []g.GoldilocksField) HashOut { } func HashNToMNoPad(input []g.GoldilocksField, numOutputs int) []g.GoldilocksField { + for i := 0; i < len(input); i++ { + if uint64(input[i]) >= g.ORDER { + panic("input contains non-canonical field element (value >= ORDER)") + } + } var perm [WIDTH]g.GoldilocksField for i := 0; i < len(input); i += RATE { for j := 0; j < RATE && i+j < len(input); j++ { @@ -96,18 +102,73 @@ func HashNToMNoPad(input []g.GoldilocksField, numOutputs int) []g.GoldilocksFiel } } +// we want to be able to hash arbitrary messages to a digest. +// we parse the input as follows: +// +// pad the input with 0s to a multiple of 7 +// for each 7 bytes; convert them into a goldilocks field element +// +// this creates len(input)/7 field elements +// +// we use Poseidon hash function to absort it, in each permutation, we absorb +// +// RATE = 8 +// +// field elements. +// +// Once we absorbed all field elements, return the first numOutputs field elements. +// If numOutputs > RATE, keep permute and append the output. +// +// Note if we convert 8 bytes into a field element, then in current setup, +// an attacker may find a collision for two different messages m1, m2 where +// +// u64(m1[:8]) = u64(m2[:8]) + Goldilocks::Order +func HashNToMPadBytes(input []byte, numOutputs int) []g.GoldilocksField { + absorbLen := g.Bytes - 1 + if len(input) == 0 { + return HashNToMNoPad(nil, numOutputs) + } + + chunkCount := (len(input) + absorbLen - 1) / absorbLen + fields := make([]g.GoldilocksField, chunkCount) + for i := 0; i < chunkCount; i++ { + start := i * absorbLen + end := start + absorbLen + if end > len(input) { + end = len(input) + } + + var paddedChunk [g.Bytes]byte + copy(paddedChunk[:], input[start:end]) + fields[i] = g.FromCanonicalLittleEndianBytesF(paddedChunk[:]) + } + + return HashNToMNoPad(fields, numOutputs) +} + +// Backward-compatible wrapper: identical to HashNToMPadBytes. func HashNToMNoPadBytes(input []byte, numOutputs int) []g.GoldilocksField { + return HashNToMPadBytes(input, numOutputs) +} + +// Hash bytes to field elements. Assumes every 8 bytes maps to a canonical field element. +// Panics if the field element is non-canonical. +func HashNToMCanonicalBytes(input []byte, numOutputs int) []g.GoldilocksField { + if len(input)%g.Bytes != 0 { panic("input length should be multiple of 8") } - inputLen := len(input) / g.Bytes var perm [WIDTH]g.GoldilocksField for i := 0; i < inputLen; i += RATE { for j := 0; j < RATE && i+j < inputLen; j++ { index := (i + j) * g.Bytes - perm[j] = g.FromCanonicalLittleEndianBytesF(input[index : index+g.Bytes]) + elem := binary.LittleEndian.Uint64(input[index : index+g.Bytes]) + if elem >= g.ORDER { + panic("input contains non-canonical field element (value >= ORDER)") + } + perm[j] = g.GoldilocksField(elem) } Permute(&perm) } diff --git a/hash/poseidon2_goldilocks_plonky2/poseidon2_test.go b/hash/poseidon2_goldilocks_plonky2/poseidon2_test.go index 6da36f9..eefad4c 100644 --- a/hash/poseidon2_goldilocks_plonky2/poseidon2_test.go +++ b/hash/poseidon2_goldilocks_plonky2/poseidon2_test.go @@ -2,6 +2,7 @@ package poseidon2_plonky2 import ( "bytes" + "encoding/binary" "math" "testing" @@ -112,15 +113,13 @@ func TestDigest(t *testing.T) { inputs[1][6] = 1 inputs[1][7] = 0 - g1 := g.FromCanonicalLittleEndianBytesF(inputs[0]) // 289077004332300282 - g2 := g.FromCanonicalLittleEndianBytesF(inputs[1]) // 289644378102298614 - hFunc.Write(inputs[0]) hFunc.Write(inputs[1]) hash := hFunc.Sum(nil) - hash2Elems := HashNoPad([]g.GoldilocksField{g1, g2}) + messageBytes := append(append([]byte{}, inputs[0]...), inputs[1]...) + hash2Elems := HashNToHashNoPad(bytesToFieldElements(messageBytes)) hash2 := hash2Elems.ToLittleEndianBytes() if !bytes.Equal(hash, hash2) { @@ -173,27 +172,19 @@ func TestHashNToHashNoPad(t *testing.T) { } func TestHashNToHashNoPadLarge(t *testing.T) { - res := HashNToHashNoPad([]g.GoldilocksField{ + defer func() { + if r := recover(); r == nil { + t.Errorf("HashNToHashNoPad should panic on non-canonical input (value >= ORDER)") + } + }() + + HashNToHashNoPad([]g.GoldilocksField{ g.GoldilocksField(g.ORDER + 1), g.GoldilocksField(g.ORDER + 2), g.GoldilocksField(g.ORDER + 3), g.GoldilocksField(math.MaxUint64), g.GoldilocksField(math.MaxUint64 - 1), }) - - expected := HashOut{ - 14216040864787980138, - 17275303675000904868, - 11831395338463193314, - 281267649235863375, - } - - for i := 0; i < 4; i++ { - if res[i] != expected[i] { - t.Logf("Expected: %v, got: %v\n", expected, res) - t.FailNow() - } - } } func TestHashTwoToOne(t *testing.T) { @@ -301,3 +292,83 @@ func TestConstantsAreInTheField(t *testing.T) { } } } + +func TestHashNToMNoPadBytesMatchesFieldHashSingleChunk(t *testing.T) { + input := []byte{1, 2, 3, 4, 5, 6, 7} + expected := HashNToMNoPad(bytesToFieldElements(input), 4) + result := HashNToMNoPadBytes(input, 4) + + compareFieldSlices(t, result, expected) +} + +func TestHashNToMNoPadBytesMatchesFieldHashPartialChunk(t *testing.T) { + input := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9} + expected := HashNToMNoPad(bytesToFieldElements(input), 4) + result := HashNToMNoPadBytes(input, 4) + + compareFieldSlices(t, result, expected) +} + +func compareFieldSlices(t *testing.T, got, want []g.GoldilocksField) { + if len(got) != len(want) { + t.Fatalf("expected %d output elements, got %d", len(want), len(got)) + } + + for i := range got { + if got[i] != want[i] { + t.Fatalf("expected element %d to be %d, got %d", i, want[i], got[i]) + } + } +} + +func bytesToFieldElements(input []byte) []g.GoldilocksField { + absorbLen := g.Bytes - 1 + if len(input) == 0 { + return nil + } + + chunkCount := (len(input) + absorbLen - 1) / absorbLen + fields := make([]g.GoldilocksField, chunkCount) + + for i := 0; i < chunkCount; i++ { + start := i * absorbLen + end := start + absorbLen + if end > len(input) { + end = len(input) + } + + var paddedChunk [g.Bytes]byte + copy(paddedChunk[:], input[start:end]) + fields[i] = g.FromCanonicalLittleEndianBytesF(paddedChunk[:]) + } + + return fields +} + +func TestHashNToMCanonicalBytesMatchesFieldHash(t *testing.T) { + inputFields := []g.GoldilocksField{ + 1, 2, 3, 4, 5, 6, 7, 8, + } + inputBytes := make([]byte, len(inputFields)*g.Bytes) + for i, f := range inputFields { + copy(inputBytes[i*g.Bytes:(i+1)*g.Bytes], g.ToLittleEndianBytesF(f)) + } + + expected := HashNToMNoPad(inputFields, 4) + result := HashNToMCanonicalBytes(inputBytes, 4) + + compareFieldSlices(t, result, expected) +} + +func TestHashNToMCanonicalBytesPanicsOnNonCanonical(t *testing.T) { + input := make([]byte, g.Bytes) + binary.LittleEndian.PutUint64(input, g.ORDER) + + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected panic for non-canonical field element") + } + }() + + _ = HashNToMCanonicalBytes(input, 1) +} diff --git a/signature/schnorr/schnorr.go b/signature/schnorr/schnorr.go index 25e463c..eeabe6f 100644 --- a/signature/schnorr/schnorr.go +++ b/signature/schnorr/schnorr.go @@ -93,7 +93,37 @@ func SchnorrPkFromSk(sk curve.ECgFp5Scalar) gFp5.Element { return curve.GENERATOR_ECgFp5Point.Mul(sk).Encode() } -func SchnorrSignHashedMessage(hashedMsg gFp5.Element, sk curve.ECgFp5Scalar) Signature { +// Sign the bytes; with the assumption that each 8 bytes is mapped to a canonical Field elements +func SchnorrSignCanonicalBytes(msg []byte, sk curve.ECgFp5Scalar) Signature { + // Hash bytes directly to quintic extension (5 field elements) + msgElements := p2.HashNToMCanonicalBytes(msg, 5) + // we are sure msgElements are in the cononical form already + hashedMsg := gFp5.FromPlonky2GoldilocksField(msgElements) + // Sign the hashed message + return schnorrSignHashedMessage(hashedMsg, sk) +} + +// Signing arbitrary length bytes. +func SchnorrSignBytes(msg []byte, sk curve.ECgFp5Scalar) Signature { + // Hash bytes directly to quintic extension (5 field elements) + msgElements := p2.HashNToMPadBytes(msg, 5) + // we are sure msgElements are in the cononical form already + hashedMsg := gFp5.FromPlonky2GoldilocksField(msgElements) + // Sign the hashed message + return schnorrSignHashedMessage(hashedMsg, sk) +} + +// Signing field elements. Panics if the message are not in canonical form. +func SchnorrSignFieldElements(msgElements []g.GoldilocksField, sk curve.ECgFp5Scalar) Signature { + // Hash field elements to quintic extension + // panic if the msgElements is not in canonical form + hashedMsg := p2.HashToQuinticExtension(msgElements) + // Sign the hashed message + return schnorrSignHashedMessage(hashedMsg, sk) +} + +// Core signing function. hashedMsg are ensured to be in the right form. +func schnorrSignHashedMessage(hashedMsg gFp5.Element, sk curve.ECgFp5Scalar) Signature { // Sample random scalar `k` and compute `r = k * G` k := curve.SampleScalar() r := curve.GENERATOR_ECgFp5Point.Mul(k).Encode() diff --git a/signature/schnorr/schnorr_test.go b/signature/schnorr/schnorr_test.go index 4b03588..d07866a 100644 --- a/signature/schnorr/schnorr_test.go +++ b/signature/schnorr/schnorr_test.go @@ -19,28 +19,51 @@ func TestSchnorrSignAndVerify(t *testing.T) { } hashedMsg := p2.HashToQuinticExtension(msg) - sig := SchnorrSignHashedMessage(hashedMsg, sk) + sig := SchnorrSignFieldElements(msg, sk) pk := SchnorrPkFromSk(sk) if !IsSchnorrSignatureValid(pk, hashedMsg, sig) { t.Fatalf("Signature is invalid") } } +func TestSchnorrSignBytesAndVerify(t *testing.T) { + sk := curve.SampleScalar() // Sample a secret key + msgBytes := make([]byte, 244*8) // 244 field elements * 8 bytes each + for i := range msgBytes { + msgBytes[i] = byte(i % 256) + } + + sig := SchnorrSignBytes(msgBytes, sk) + pk := SchnorrPkFromSk(sk) + + // Hash the message the same way as SchnorrSignBytes does + msgElements := p2.HashNToMNoPadBytes(msgBytes, 5) + hashedMsg := gFp5.FromPlonky2GoldilocksField(msgElements) + + if !IsSchnorrSignatureValid(pk, hashedMsg, sig) { + t.Fatalf("Signature is invalid") + } +} + func FuzzTestSchnorrSignAndVerify(f *testing.F) { f.Add([]byte{1, 2, 3, 4}, []byte{5, 6, 7, 8}) - f.Fuzz(func(t *testing.T, a, b []byte) { - scalar := curve.FromNonCanonicalBigInt(new(big.Int).SetBytes(a)) - - msgBytes := make([]g.GoldilocksField, 0) - for i := 0; i < len(b); i += 8 { - var chunk [8]byte - copy(chunk[:], b[i:min(i+8, len(b))]) - msgBytes = append(msgBytes, g.GoldilocksField(binary.LittleEndian.Uint64(chunk[:]))) - } - hashedMsg := p2.HashToQuinticExtension(msgBytes) + f.Fuzz(func(t *testing.T, a, b []byte) { + scalar := curve.FromNonCanonicalBigInt(new(big.Int).SetBytes(a)) + + msgBytes := make([]g.GoldilocksField, 0) + for i := 0; i < len(b); i += 8 { + var chunk [8]byte + copy(chunk[:], b[i:min(i+8, len(b))]) + val := binary.LittleEndian.Uint64(chunk[:]) + if val >= g.ORDER { + val -= g.ORDER + } + msgBytes = append(msgBytes, g.GoldilocksField(val)) + } + hashedMsg := p2.HashToQuinticExtension(msgBytes) - sig := SchnorrSignHashedMessage(hashedMsg, scalar) + sig := SchnorrSignFieldElements(msgBytes, scalar) pk := SchnorrPkFromSk(scalar) if !IsSchnorrSignatureValid(pk, hashedMsg, sig) { t.Fatalf("Signature is invalid") @@ -175,7 +198,7 @@ func TestBytes(t *testing.T) { } hashedMsg := p2.HashToQuinticExtension(msg) // Random message - sig := SchnorrSignHashedMessage(hashedMsg, sk) + sig := SchnorrSignFieldElements(msg, sk) sig2, err := SigFromBytes(sig.ToBytes()) if err != nil { t.Fatalf("Failed to convert signature bytes to Schnorr signature: %v", err) @@ -222,10 +245,9 @@ func BenchmarkSignatureSign(b *testing.B) { for i := 0; i < 244; i++ { msg[i] = g.SampleF() } - hashedMsg := p2.HashToQuinticExtension(msg) b.ResetTimer() for i := 0; i < b.N; i++ { - _ = SchnorrSignHashedMessage(hashedMsg, sk) + _ = SchnorrSignFieldElements(msg, sk) } }