Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ vendor/
# Go workspace file
go.work
go.work.sum
.gocache

# env file
.env
65 changes: 63 additions & 2 deletions hash/poseidon2_goldilocks_plonky2/poseidon2.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package poseidon2_plonky2

import (
"encoding/binary"
"fmt"
"hash"

Expand Down Expand Up @@ -76,6 +77,11 @@ func HashNToHashNoPad(input []g.GoldilocksField) HashOut {
}

func HashNToMNoPad(input []g.GoldilocksField, numOutputs int) []g.GoldilocksField {
for i := 0; i < len(input); i++ {
if uint64(input[i]) >= g.ORDER {
panic("input contains non-canonical field element (value >= ORDER)")
}
}
var perm [WIDTH]g.GoldilocksField
for i := 0; i < len(input); i += RATE {
for j := 0; j < RATE && i+j < len(input); j++ {
Expand All @@ -96,18 +102,73 @@ func HashNToMNoPad(input []g.GoldilocksField, numOutputs int) []g.GoldilocksFiel
}
}

// we want to be able to hash arbitrary messages to a digest.
// we parse the input as follows:
//
// pad the input with 0s to a multiple of 7
// for each 7 bytes; convert them into a goldilocks field element
//
// this creates len(input)/7 field elements
//
// we use Poseidon hash function to absort it, in each permutation, we absorb
//
// RATE = 8
//
// field elements.
//
// Once we absorbed all field elements, return the first numOutputs field elements.
// If numOutputs > RATE, keep permute and append the output.
//
// Note if we convert 8 bytes into a field element, then in current setup,
// an attacker may find a collision for two different messages m1, m2 where
//
// u64(m1[:8]) = u64(m2[:8]) + Goldilocks::Order
func HashNToMPadBytes(input []byte, numOutputs int) []g.GoldilocksField {
absorbLen := g.Bytes - 1
if len(input) == 0 {
return HashNToMNoPad(nil, numOutputs)
}

chunkCount := (len(input) + absorbLen - 1) / absorbLen
fields := make([]g.GoldilocksField, chunkCount)
for i := 0; i < chunkCount; i++ {
start := i * absorbLen
end := start + absorbLen
if end > len(input) {
end = len(input)
}

var paddedChunk [g.Bytes]byte
copy(paddedChunk[:], input[start:end])
fields[i] = g.FromCanonicalLittleEndianBytesF(paddedChunk[:])
}

return HashNToMNoPad(fields, numOutputs)
}

// Backward-compatible wrapper: identical to HashNToMPadBytes.
func HashNToMNoPadBytes(input []byte, numOutputs int) []g.GoldilocksField {
return HashNToMPadBytes(input, numOutputs)
}

// Hash bytes to field elements. Assumes every 8 bytes maps to a canonical field element.
// Panics if the field element is non-canonical.
func HashNToMCanonicalBytes(input []byte, numOutputs int) []g.GoldilocksField {

if len(input)%g.Bytes != 0 {
panic("input length should be multiple of 8")
}

inputLen := len(input) / g.Bytes

var perm [WIDTH]g.GoldilocksField
for i := 0; i < inputLen; i += RATE {
for j := 0; j < RATE && i+j < inputLen; j++ {
index := (i + j) * g.Bytes
perm[j] = g.FromCanonicalLittleEndianBytesF(input[index : index+g.Bytes])
elem := binary.LittleEndian.Uint64(input[index : index+g.Bytes])
if elem >= g.ORDER {
panic("input contains non-canonical field element (value >= ORDER)")
}
perm[j] = g.GoldilocksField(elem)
}
Permute(&perm)
}
Expand Down
109 changes: 90 additions & 19 deletions hash/poseidon2_goldilocks_plonky2/poseidon2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package poseidon2_plonky2

import (
"bytes"
"encoding/binary"
"math"
"testing"

Expand Down Expand Up @@ -112,15 +113,13 @@ func TestDigest(t *testing.T) {
inputs[1][6] = 1
inputs[1][7] = 0

g1 := g.FromCanonicalLittleEndianBytesF(inputs[0]) // 289077004332300282
g2 := g.FromCanonicalLittleEndianBytesF(inputs[1]) // 289644378102298614

hFunc.Write(inputs[0])
hFunc.Write(inputs[1])

hash := hFunc.Sum(nil)

hash2Elems := HashNoPad([]g.GoldilocksField{g1, g2})
messageBytes := append(append([]byte{}, inputs[0]...), inputs[1]...)
hash2Elems := HashNToHashNoPad(bytesToFieldElements(messageBytes))
hash2 := hash2Elems.ToLittleEndianBytes()

if !bytes.Equal(hash, hash2) {
Expand Down Expand Up @@ -173,27 +172,19 @@ func TestHashNToHashNoPad(t *testing.T) {
}

func TestHashNToHashNoPadLarge(t *testing.T) {
res := HashNToHashNoPad([]g.GoldilocksField{
defer func() {
if r := recover(); r == nil {
t.Errorf("HashNToHashNoPad should panic on non-canonical input (value >= ORDER)")
}
}()

HashNToHashNoPad([]g.GoldilocksField{
g.GoldilocksField(g.ORDER + 1),
g.GoldilocksField(g.ORDER + 2),
g.GoldilocksField(g.ORDER + 3),
g.GoldilocksField(math.MaxUint64),
g.GoldilocksField(math.MaxUint64 - 1),
})

expected := HashOut{
14216040864787980138,
17275303675000904868,
11831395338463193314,
281267649235863375,
}

for i := 0; i < 4; i++ {
if res[i] != expected[i] {
t.Logf("Expected: %v, got: %v\n", expected, res)
t.FailNow()
}
}
}

func TestHashTwoToOne(t *testing.T) {
Expand Down Expand Up @@ -301,3 +292,83 @@ func TestConstantsAreInTheField(t *testing.T) {
}
}
}

func TestHashNToMNoPadBytesMatchesFieldHashSingleChunk(t *testing.T) {
input := []byte{1, 2, 3, 4, 5, 6, 7}
expected := HashNToMNoPad(bytesToFieldElements(input), 4)
result := HashNToMNoPadBytes(input, 4)

compareFieldSlices(t, result, expected)
}

func TestHashNToMNoPadBytesMatchesFieldHashPartialChunk(t *testing.T) {
input := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}
expected := HashNToMNoPad(bytesToFieldElements(input), 4)
result := HashNToMNoPadBytes(input, 4)

compareFieldSlices(t, result, expected)
}

func compareFieldSlices(t *testing.T, got, want []g.GoldilocksField) {
if len(got) != len(want) {
t.Fatalf("expected %d output elements, got %d", len(want), len(got))
}

for i := range got {
if got[i] != want[i] {
t.Fatalf("expected element %d to be %d, got %d", i, want[i], got[i])
}
}
}

func bytesToFieldElements(input []byte) []g.GoldilocksField {
absorbLen := g.Bytes - 1
if len(input) == 0 {
return nil
}

chunkCount := (len(input) + absorbLen - 1) / absorbLen
fields := make([]g.GoldilocksField, chunkCount)

for i := 0; i < chunkCount; i++ {
start := i * absorbLen
end := start + absorbLen
if end > len(input) {
end = len(input)
}

var paddedChunk [g.Bytes]byte
copy(paddedChunk[:], input[start:end])
fields[i] = g.FromCanonicalLittleEndianBytesF(paddedChunk[:])
}

return fields
}

func TestHashNToMCanonicalBytesMatchesFieldHash(t *testing.T) {
inputFields := []g.GoldilocksField{
1, 2, 3, 4, 5, 6, 7, 8,
}
inputBytes := make([]byte, len(inputFields)*g.Bytes)
for i, f := range inputFields {
copy(inputBytes[i*g.Bytes:(i+1)*g.Bytes], g.ToLittleEndianBytesF(f))
}

expected := HashNToMNoPad(inputFields, 4)
result := HashNToMCanonicalBytes(inputBytes, 4)

compareFieldSlices(t, result, expected)
}

func TestHashNToMCanonicalBytesPanicsOnNonCanonical(t *testing.T) {
input := make([]byte, g.Bytes)
binary.LittleEndian.PutUint64(input, g.ORDER)

defer func() {
if r := recover(); r == nil {
t.Fatalf("expected panic for non-canonical field element")
}
}()

_ = HashNToMCanonicalBytes(input, 1)
}
32 changes: 31 additions & 1 deletion signature/schnorr/schnorr.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,37 @@ func SchnorrPkFromSk(sk curve.ECgFp5Scalar) gFp5.Element {
return curve.GENERATOR_ECgFp5Point.Mul(sk).Encode()
}

func SchnorrSignHashedMessage(hashedMsg gFp5.Element, sk curve.ECgFp5Scalar) Signature {
// Sign the bytes; with the assumption that each 8 bytes is mapped to a canonical Field elements
func SchnorrSignCanonicalBytes(msg []byte, sk curve.ECgFp5Scalar) Signature {
// Hash bytes directly to quintic extension (5 field elements)
msgElements := p2.HashNToMCanonicalBytes(msg, 5)
// we are sure msgElements are in the cononical form already
hashedMsg := gFp5.FromPlonky2GoldilocksField(msgElements)
// Sign the hashed message
return schnorrSignHashedMessage(hashedMsg, sk)
}

// Signing arbitrary length bytes.
func SchnorrSignBytes(msg []byte, sk curve.ECgFp5Scalar) Signature {
// Hash bytes directly to quintic extension (5 field elements)
msgElements := p2.HashNToMPadBytes(msg, 5)
// we are sure msgElements are in the cononical form already
hashedMsg := gFp5.FromPlonky2GoldilocksField(msgElements)
// Sign the hashed message
return schnorrSignHashedMessage(hashedMsg, sk)
}

// Signing field elements. Panics if the message are not in canonical form.
func SchnorrSignFieldElements(msgElements []g.GoldilocksField, sk curve.ECgFp5Scalar) Signature {
// Hash field elements to quintic extension
// panic if the msgElements is not in canonical form
hashedMsg := p2.HashToQuinticExtension(msgElements)
// Sign the hashed message
return schnorrSignHashedMessage(hashedMsg, sk)
}

// Core signing function. hashedMsg are ensured to be in the right form.
func schnorrSignHashedMessage(hashedMsg gFp5.Element, sk curve.ECgFp5Scalar) Signature {
// Sample random scalar `k` and compute `r = k * G`
k := curve.SampleScalar()
r := curve.GENERATOR_ECgFp5Point.Mul(k).Encode()
Expand Down
52 changes: 37 additions & 15 deletions signature/schnorr/schnorr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,51 @@ func TestSchnorrSignAndVerify(t *testing.T) {
}
hashedMsg := p2.HashToQuinticExtension(msg)

sig := SchnorrSignHashedMessage(hashedMsg, sk)
sig := SchnorrSignFieldElements(msg, sk)
pk := SchnorrPkFromSk(sk)
if !IsSchnorrSignatureValid(pk, hashedMsg, sig) {
t.Fatalf("Signature is invalid")
}
}

func TestSchnorrSignBytesAndVerify(t *testing.T) {
sk := curve.SampleScalar() // Sample a secret key
msgBytes := make([]byte, 244*8) // 244 field elements * 8 bytes each
for i := range msgBytes {
msgBytes[i] = byte(i % 256)
}

sig := SchnorrSignBytes(msgBytes, sk)
pk := SchnorrPkFromSk(sk)

// Hash the message the same way as SchnorrSignBytes does
msgElements := p2.HashNToMNoPadBytes(msgBytes, 5)
hashedMsg := gFp5.FromPlonky2GoldilocksField(msgElements)

if !IsSchnorrSignatureValid(pk, hashedMsg, sig) {
t.Fatalf("Signature is invalid")
}
}

func FuzzTestSchnorrSignAndVerify(f *testing.F) {
f.Add([]byte{1, 2, 3, 4}, []byte{5, 6, 7, 8})

f.Fuzz(func(t *testing.T, a, b []byte) {
scalar := curve.FromNonCanonicalBigInt(new(big.Int).SetBytes(a))

msgBytes := make([]g.GoldilocksField, 0)
for i := 0; i < len(b); i += 8 {
var chunk [8]byte
copy(chunk[:], b[i:min(i+8, len(b))])
msgBytes = append(msgBytes, g.GoldilocksField(binary.LittleEndian.Uint64(chunk[:])))
}
hashedMsg := p2.HashToQuinticExtension(msgBytes)
f.Fuzz(func(t *testing.T, a, b []byte) {
scalar := curve.FromNonCanonicalBigInt(new(big.Int).SetBytes(a))

msgBytes := make([]g.GoldilocksField, 0)
for i := 0; i < len(b); i += 8 {
var chunk [8]byte
copy(chunk[:], b[i:min(i+8, len(b))])
val := binary.LittleEndian.Uint64(chunk[:])
if val >= g.ORDER {
val -= g.ORDER
}
msgBytes = append(msgBytes, g.GoldilocksField(val))
}
hashedMsg := p2.HashToQuinticExtension(msgBytes)

sig := SchnorrSignHashedMessage(hashedMsg, scalar)
sig := SchnorrSignFieldElements(msgBytes, scalar)
pk := SchnorrPkFromSk(scalar)
if !IsSchnorrSignatureValid(pk, hashedMsg, sig) {
t.Fatalf("Signature is invalid")
Expand Down Expand Up @@ -175,7 +198,7 @@ func TestBytes(t *testing.T) {
}
hashedMsg := p2.HashToQuinticExtension(msg) // Random message

sig := SchnorrSignHashedMessage(hashedMsg, sk)
sig := SchnorrSignFieldElements(msg, sk)
sig2, err := SigFromBytes(sig.ToBytes())
if err != nil {
t.Fatalf("Failed to convert signature bytes to Schnorr signature: %v", err)
Expand Down Expand Up @@ -222,10 +245,9 @@ func BenchmarkSignatureSign(b *testing.B) {
for i := 0; i < 244; i++ {
msg[i] = g.SampleF()
}
hashedMsg := p2.HashToQuinticExtension(msg)

b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = SchnorrSignHashedMessage(hashedMsg, sk)
_ = SchnorrSignFieldElements(msg, sk)
}
}