Skip to content

Commit d32b17c

Browse files
committed
adding circuit test: binary range proof
1 parent a3ad2ec commit d32b17c

File tree

1 file changed

+148
-1
lines changed

1 file changed

+148
-1
lines changed

circuit_test.go

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ func TestArithmeticCircuit2(t *testing.T) {
185185
Fm: false,
186186

187187
F: func(typ PartitionType, index int) *int {
188-
if typ == PartitionLL { // map all to no
188+
if typ == PartitionLL { // map all to ll
189189
return &index
190190
}
191191

@@ -217,6 +217,153 @@ func TestArithmeticCircuit2(t *testing.T) {
217217
}
218218
}
219219

220+
func TestArithmeticCircuitBinaryRangeProof(t *testing.T) {
221+
value := []*big.Int{bint(0), bint(1), bint(1), bint(0)} // bin(0110) = dec(6)
222+
// We have prove that value < 2^n - 1
223+
// Proving of the bits count is automatic (public parameters dimension will not allow to verify prove for bigger value)
224+
// Then we should prove that every value is a bit
225+
// To do that have to prove that every value[i] * (value[i] - 1) = 0
226+
227+
// For the 4-bit value we will have te following constraints:
228+
// value0*value0 = a0
229+
// a0 - value0 = 0
230+
231+
// value1*value1 = a1
232+
// a1 - value1 = 0
233+
234+
// value2*value2 = a2
235+
// a2 - x2 = 0
236+
237+
// value3*value3 = a3
238+
// a3 - value3 = 0
239+
240+
Nm := 4
241+
No := 4
242+
Nv := 2
243+
K := 4
244+
245+
Nl := Nv * K // 8
246+
Nw := Nm + Nm + No // 12
247+
248+
a := hadamardMul(value, value) // a[i] = value[i] * value[i]
249+
250+
v := [][]*big.Int{
251+
{value[0], a[0]},
252+
{value[1], a[1]},
253+
{value[2], a[2]},
254+
{value[3], a[3]},
255+
}
256+
257+
wl := value
258+
wr := value
259+
wo := a
260+
261+
w := append(wl, wr...)
262+
w = append(w, wo...) // w = wl||wl||wo
263+
264+
wv := make([]*big.Int, 0, Nw)
265+
for i := range v {
266+
wv = append(wv, v[i]...)
267+
}
268+
269+
Wm := [][]*big.Int{
270+
{bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(1), bint(0), bint(0), bint(0)},
271+
{bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(1), bint(0), bint(0)},
272+
{bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(1), bint(0)},
273+
{bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(1)},
274+
} // Nm*Nw = 4 * 12
275+
276+
Am := zeroVector(Nm) // Nm
277+
278+
Wl := [][]*big.Int{
279+
{bint(-1), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0)},
280+
{bint(-1), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0)},
281+
{bint(0), bint(-1), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0)},
282+
{bint(0), bint(-1), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0)},
283+
{bint(0), bint(0), bint(-1), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0)},
284+
{bint(0), bint(0), bint(-1), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0)},
285+
{bint(0), bint(0), bint(0), bint(-1), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0)},
286+
{bint(0), bint(0), bint(0), bint(-1), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0), bint(0)},
287+
} // Nl*Nw = 8 * 12
288+
289+
Al := zeroVector(Nl)
290+
291+
fmt.Println("Circuit check:", matrixMulOnVector(w, Wm), "=", hadamardMul(wl, wr))
292+
fmt.Println("Circuit check:", vectorAdd(vectorAdd(matrixMulOnVector(w, Wl), wv), Al), "= 0")
293+
294+
wnla := NewWeightNormLinearPublic(16, Nm)
295+
296+
public := &ArithmeticCircuitPublic{
297+
Nm: Nm,
298+
Nl: Nl,
299+
Nv: Nv,
300+
Nw: Nw,
301+
No: No,
302+
K: K,
303+
304+
G: wnla.G,
305+
GVec: wnla.GVec[:Nm],
306+
HVec: wnla.HVec[:9+Nv],
307+
308+
Wm: Wm,
309+
Wl: Wl,
310+
Am: Am,
311+
Al: Al,
312+
Fl: true,
313+
Fm: false,
314+
315+
F: func(typ PartitionType, index int) *int {
316+
if typ == PartitionNO { // map all to no
317+
return &index
318+
}
319+
320+
return nil
321+
},
322+
323+
GVec_: wnla.GVec[Nm:],
324+
HVec_: wnla.HVec[9+Nv:],
325+
}
326+
327+
private := &ArithmeticCircuitPrivate{
328+
V: v,
329+
Sv: []*big.Int{MustRandScalar(), MustRandScalar(), MustRandScalar(), MustRandScalar()},
330+
Wl: wl,
331+
Wr: wr,
332+
Wo: wo,
333+
}
334+
335+
V := make([]*bn256.G1, public.K)
336+
for i := range V {
337+
V[i] = public.CommitCircuit(private.V[i], private.Sv[i])
338+
}
339+
340+
proof := ProveCircuit(public, NewKeccakFS(), private)
341+
spew.Dump(proof)
342+
343+
if err := VerifyCircuit(public, V, NewKeccakFS(), proof); err != nil {
344+
panic(err)
345+
}
346+
}
347+
348+
func matrixMulOnVector(a []*big.Int, m [][]*big.Int) []*big.Int {
349+
var res []*big.Int
350+
351+
for i := 0; i < len(m); i++ {
352+
res = append(res, vectorMul(a, m[i]))
353+
}
354+
355+
return res
356+
}
357+
358+
func hadamardMul(a, b []*big.Int) []*big.Int {
359+
res := make([]*big.Int, len(a))
360+
for i := range res {
361+
res[i] = mul(a[i], b[i])
362+
}
363+
364+
return res
365+
}
366+
220367
func frac(a, b int) *big.Int {
221368
return mul(bint(a), inv(bint(b)))
222369
}

0 commit comments

Comments
 (0)