Skip to content

Commit ca62f99

Browse files
committed
Refactor PRG to single impl
1 parent 1a8849e commit ca62f99

15 files changed

Lines changed: 249 additions & 355 deletions

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ use rand::prelude::*;
2828
// Matyas-Meyer-Oseas (via AES128) provides 128-bit security and should be enough.
2929
// Hirose (via AES256) still only provides 128-bit security because the output is not chained.
3030
// But Hirose can be helpful is you are forced to choose AES256.
31-
use fss_rs::dcf::prg::Aes128MatyasMeyerOseasPrg;
31+
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;
3232
use fss_rs::dcf::{Dcf, DcfImpl};
3333

3434
let keys: [[u8; 32]; 2] = thread_rng().gen();
35-
let prg = Aes128MatyasMeyerOseasPrg::<16, 2>::new(std::array::from_fn(|i| &keys[i]));
35+
let prg = Aes128MatyasMeyerOseasPrg::<16, 2, 2>::new(std::array::from_fn(|i| &keys[i]));
3636
// DCF for example
3737
let dcf = DcfImpl::<16, 16, _>::new(prg);
3838
```

benches/dcf_eval.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
55
use rand::prelude::*;
66

7-
use fss_rs::dcf::prg::Aes128MatyasMeyerOseasPrg;
87
use fss_rs::dcf::{BoundState, CmpFn, Dcf, DcfImpl};
98
use fss_rs::group::byte::ByteGroup;
109
use fss_rs::group::Group;
10+
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;
1111

1212
fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIPHER_N: usize>(
1313
c: &mut Criterion,
@@ -16,7 +16,7 @@ fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIP
1616
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
1717
let keys_iter = std::array::from_fn(|i| &keys[i]);
1818

19-
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, CIPHER_N>::new(keys_iter);
19+
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, 2, CIPHER_N>::new(keys_iter);
2020
let dcf = DcfImpl::<IN_BLEN, OUT_BLEN, _>::new(prg);
2121

2222
let mut s0s = [[0; OUT_BLEN]; 2];

benches/dcf_eval_batch.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
55
use rand::prelude::*;
66

7-
use fss_rs::dcf::prg::Aes128MatyasMeyerOseasPrg;
87
use fss_rs::dcf::{BoundState, CmpFn, Dcf, DcfImpl};
98
use fss_rs::group::byte::ByteGroup;
109
use fss_rs::group::Group;
10+
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;
1111

1212
const POINT_NUM: usize = 10000;
1313

@@ -18,7 +18,7 @@ fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIP
1818
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
1919
let keys_iter = std::array::from_fn(|i| &keys[i]);
2020

21-
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, CIPHER_N>::new(keys_iter);
21+
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, 2, CIPHER_N>::new(keys_iter);
2222
let dcf = DcfImpl::<IN_BLEN, OUT_BLEN, _>::new(prg);
2323

2424
let mut s0s = [[0; OUT_BLEN]; 2];

benches/dcf_full_eval.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
55
use rand::prelude::*;
66

7-
use fss_rs::dcf::prg::Aes128MatyasMeyerOseasPrg;
87
use fss_rs::dcf::{BoundState, CmpFn, Dcf, DcfImpl};
98
use fss_rs::group::byte::ByteGroup;
109
use fss_rs::group::Group;
10+
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;
1111

1212
fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIPHER_N: usize>(
1313
c: &mut Criterion,
@@ -17,7 +17,7 @@ fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIP
1717
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
1818
let keys_iter = std::array::from_fn(|i| &keys[i]);
1919

20-
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, CIPHER_N>::new(keys_iter);
20+
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, 2, CIPHER_N>::new(keys_iter);
2121
let dcf = DcfImpl::<IN_BLEN, OUT_BLEN, _>::new_with_filter(prg, filter_bitn);
2222

2323
let mut s0s = [[0; OUT_BLEN]; 2];

benches/dcf_gen.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
55
use rand::prelude::*;
66

7-
use fss_rs::dcf::prg::Aes128MatyasMeyerOseasPrg;
87
use fss_rs::dcf::{BoundState, CmpFn, Dcf, DcfImpl};
98
use fss_rs::group::byte::ByteGroup;
9+
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;
1010

1111
fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIPHER_N: usize>(
1212
c: &mut Criterion,
@@ -15,7 +15,7 @@ fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIP
1515
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
1616
let keys_iter = std::array::from_fn(|i| &keys[i]);
1717

18-
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, CIPHER_N>::new(keys_iter);
18+
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, 2, CIPHER_N>::new(keys_iter);
1919
let dcf = DcfImpl::<IN_BLEN, OUT_BLEN, _>::new(prg);
2020

2121
let mut s0s = [[0; OUT_BLEN]; 2];

benches/dpf_eval.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
55
use rand::prelude::*;
66

7-
use fss_rs::dpf::prg::Aes128MatyasMeyerOseasPrg;
87
use fss_rs::dpf::{Dpf, DpfImpl, PointFn};
98
use fss_rs::group::byte::ByteGroup;
109
use fss_rs::group::Group;
10+
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;
1111

1212
fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIPHER_N: usize>(
1313
c: &mut Criterion,
@@ -16,7 +16,7 @@ fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIP
1616
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
1717
let keys_iter = std::array::from_fn(|i| &keys[i]);
1818

19-
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, CIPHER_N>::new(keys_iter);
19+
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, 1, CIPHER_N>::new(keys_iter);
2020
let dpf = DpfImpl::<IN_BLEN, OUT_BLEN, _>::new(prg);
2121

2222
let mut s0s = [[0; OUT_BLEN]; 2];

benches/dpf_eval_batch.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
55
use rand::prelude::*;
66

7-
use fss_rs::dpf::prg::Aes128MatyasMeyerOseasPrg;
87
use fss_rs::dpf::{Dpf, DpfImpl, PointFn};
98
use fss_rs::group::byte::ByteGroup;
109
use fss_rs::group::Group;
10+
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;
1111

1212
const POINT_NUM: usize = 10000;
1313

@@ -18,7 +18,7 @@ fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIP
1818
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
1919
let keys_iter = std::array::from_fn(|i| &keys[i]);
2020

21-
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, CIPHER_N>::new(keys_iter);
21+
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, 1, CIPHER_N>::new(keys_iter);
2222
let dpf = DpfImpl::<IN_BLEN, OUT_BLEN, _>::new(prg);
2323

2424
let mut s0s = [[0; OUT_BLEN]; 2];

benches/dpf_full_eval.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
55
use rand::prelude::*;
66

7-
use fss_rs::dpf::prg::Aes128MatyasMeyerOseasPrg;
87
use fss_rs::dpf::{Dpf, DpfImpl, PointFn};
98
use fss_rs::group::byte::ByteGroup;
109
use fss_rs::group::Group;
10+
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;
1111

1212
fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIPHER_N: usize>(
1313
c: &mut Criterion,
@@ -17,7 +17,7 @@ fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIP
1717
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
1818
let keys_iter = std::array::from_fn(|i| &keys[i]);
1919

20-
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, CIPHER_N>::new(keys_iter);
20+
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, 1, CIPHER_N>::new(keys_iter);
2121
let dpf = DpfImpl::<IN_BLEN, OUT_BLEN, _>::new_with_filter(prg, filter_bitn);
2222

2323
let mut s0s = [[0; OUT_BLEN]; 2];

benches/dpf_gen.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
55
use rand::prelude::*;
66

7-
use fss_rs::dpf::prg::Aes128MatyasMeyerOseasPrg;
87
use fss_rs::dpf::{Dpf, DpfImpl, PointFn};
98
use fss_rs::group::byte::ByteGroup;
9+
use fss_rs::prg::Aes128MatyasMeyerOseasPrg;
1010

1111
fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIPHER_N: usize>(
1212
c: &mut Criterion,
@@ -15,7 +15,7 @@ fn from_domain_range_size<const IN_BLEN: usize, const OUT_BLEN: usize, const CIP
1515
keys.iter_mut().for_each(|k| thread_rng().fill_bytes(k));
1616
let keys_iter = std::array::from_fn(|i| &keys[i]);
1717

18-
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, CIPHER_N>::new(keys_iter);
18+
let prg = Aes128MatyasMeyerOseasPrg::<OUT_BLEN, 1, CIPHER_N>::new(keys_iter);
1919
let dpf = DpfImpl::<IN_BLEN, OUT_BLEN, _>::new(prg);
2020

2121
let mut s0s = [[0; OUT_BLEN]; 2];

src/dcf/mod.rs

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@ use rayon::prelude::*;
99

1010
use crate::group::Group;
1111
use crate::utils::{xor, xor_inplace};
12-
use crate::{decl_prg_trait, Cw, PointFn, Share};
13-
14-
#[cfg(feature = "prg")]
15-
pub mod prg;
12+
use crate::{Cw, PointFn, Prg, Share};
1613

1714
/// Distributed comparison function API.
1815
///
@@ -64,22 +61,20 @@ where
6461
}
6562
}
6663

67-
decl_prg_trait!(([u8; OUT_BLEN], [u8; OUT_BLEN], bool));
68-
6964
/// [`Dcf`] impl.
7065
///
7166
/// `$\alpha$` itself is not included (or say exclusive endpoint), which means `$f(\alpha)$ = 0`.
7267
pub struct DcfImpl<const IN_BLEN: usize, const OUT_BLEN: usize, P>
7368
where
74-
P: Prg<OUT_BLEN>,
69+
P: Prg<OUT_BLEN, 2>,
7570
{
7671
prg: P,
7772
filter_bitn: usize,
7873
}
7974

8075
impl<const IN_BLEN: usize, const OUT_BLEN: usize, P> DcfImpl<IN_BLEN, OUT_BLEN, P>
8176
where
82-
P: Prg<OUT_BLEN>,
77+
P: Prg<OUT_BLEN, 2>,
8378
{
8479
pub fn new(prg: P) -> Self {
8580
Self {
@@ -100,7 +95,7 @@ const IDX_R: usize = 1;
10095
impl<const IN_BLEN: usize, const OUT_BLEN: usize, P, G> Dcf<IN_BLEN, OUT_BLEN, G>
10196
for DcfImpl<IN_BLEN, OUT_BLEN, P>
10297
where
103-
P: Prg<OUT_BLEN>,
98+
P: Prg<OUT_BLEN, 2>,
10499
G: Group<OUT_BLEN>,
105100
{
106101
fn gen(
@@ -119,8 +114,8 @@ where
119114
for i in 0..n {
120115
// MSB is required since we index from high to low in arrays.
121116
let alpha_i = f.alpha.view_bits::<Msb0>()[i];
122-
let [(s0l, v0l, t0l), (s0r, v0r, t0r)] = self.prg.gen(&ss_prev[0]);
123-
let [(s1l, v1l, t1l), (s1r, v1r, t1r)] = self.prg.gen(&ss_prev[1]);
117+
let [([s0l, v0l], t0l), ([s0r, v0r], t0r)] = self.prg.gen(&ss_prev[0]);
118+
let [([s1l, v1l], t1l), ([s1r, v1r], t1r)] = self.prg.gen(&ss_prev[1]);
124119
// MSB is required since we index from high to low in arrays.
125120
let (keep, lose) = if alpha_i {
126121
(IDX_R, IDX_L)
@@ -201,7 +196,7 @@ where
201196

202197
impl<const IN_BLEN: usize, const OUT_BLEN: usize, P> DcfImpl<IN_BLEN, OUT_BLEN, P>
203198
where
204-
P: Prg<OUT_BLEN>,
199+
P: Prg<OUT_BLEN, 2>,
205200
{
206201
/// Eval with single-threading.
207202
/// See [`Dcf::eval`].
@@ -255,7 +250,7 @@ where
255250

256251
let cw = &k.cws[layer_i];
257252
// `*_hat` before in-place XOR.
258-
let [(mut sl, vl_hat, mut tl), (mut sr, vr_hat, mut tr)] = self.prg.gen(&s);
253+
let [([mut sl, vl_hat], mut tl), ([mut sr, vr_hat], mut tr)] = self.prg.gen(&s);
259254
xor_inplace(&mut sl, &[if t { &cw.s } else { &[0; OUT_BLEN] }]);
260255
xor_inplace(&mut sr, &[if t { &cw.s } else { &[0; OUT_BLEN] }]);
261256
tl ^= t & cw.tl;
@@ -291,7 +286,7 @@ where
291286
for i in 0..n {
292287
let cw = &k.cws[i];
293288
// `*_hat` before in-place XOR.
294-
let [(mut sl, vl_hat, mut tl), (mut sr, vr_hat, mut tr)] = self.prg.gen(&s_prev);
289+
let [([mut sl, vl_hat], mut tl), ([mut sr, vr_hat], mut tr)] = self.prg.gen(&s_prev);
295290
xor_inplace(&mut sl, &[if t_prev { &cw.s } else { &[0; OUT_BLEN] }]);
296291
xor_inplace(&mut sr, &[if t_prev { &cw.s } else { &[0; OUT_BLEN] }]);
297292
tl ^= t_prev & cw.tl;
@@ -326,9 +321,9 @@ pub enum BoundState {
326321
mod tests {
327322
use rand::prelude::*;
328323

329-
use super::prg::Aes256HirosePrg;
330324
use super::*;
331325
use crate::group::byte::ByteGroup;
326+
use crate::prg::Aes256HirosePrg;
332327

333328
const KEYS: &[&[u8; 32]] = &[
334329
b"j9\x1b_\xb3X\xf33\xacW\x15\x1b\x0812K\xb3I\xb9\x90r\x1cN\xb5\xee9W\xd3\xbb@\xc6d",
@@ -345,7 +340,7 @@ mod tests {
345340

346341
#[test]
347342
fn test_dcf_gen_then_eval() {
348-
let prg = Aes256HirosePrg::<16, 2>::new(std::array::from_fn(|i| KEYS[i]));
343+
let prg = Aes256HirosePrg::<16, 2, 2>::new(std::array::from_fn(|i| KEYS[i]));
349344
let dcf = DcfImpl::<16, 16, _>::new(prg);
350345
let s0s: [[u8; 16]; 2] = thread_rng().gen();
351346
let f = CmpFn {
@@ -377,7 +372,7 @@ mod tests {
377372

378373
#[test]
379374
fn test_dcf_gen_gt_beta_then_eval() {
380-
let prg = Aes256HirosePrg::<16, 2>::new(std::array::from_fn(|i| KEYS[i]));
375+
let prg = Aes256HirosePrg::<16, 2, 2>::new(std::array::from_fn(|i| KEYS[i]));
381376
let dcf = DcfImpl::<16, 16, _>::new(prg);
382377
let s0s: [[u8; 16]; 2] = thread_rng().gen();
383378
let f = CmpFn {
@@ -409,7 +404,7 @@ mod tests {
409404

410405
#[test]
411406
fn test_dcf_gen_then_eval_with_filter() {
412-
let prg = Aes256HirosePrg::<16, 2>::new(std::array::from_fn(|i| KEYS[i]));
407+
let prg = Aes256HirosePrg::<16, 2, 2>::new(std::array::from_fn(|i| KEYS[i]));
413408
let dcf = DcfImpl::<16, 16, _>::new_with_filter(prg, 127);
414409
let s0s: [[u8; 16]; 2] = thread_rng().gen();
415410
let f = CmpFn {
@@ -441,7 +436,7 @@ mod tests {
441436

442437
#[test]
443438
fn test_dcf_gen_then_eval_not_zeros() {
444-
let prg = Aes256HirosePrg::<16, 2>::new(std::array::from_fn(|i| KEYS[i]));
439+
let prg = Aes256HirosePrg::<16, 2, 2>::new(std::array::from_fn(|i| KEYS[i]));
445440
let dcf = DcfImpl::<16, 16, _>::new(prg);
446441
let s0s: [[u8; 16]; 2] = thread_rng().gen();
447442
let f = CmpFn {
@@ -465,7 +460,7 @@ mod tests {
465460
#[test]
466461
fn test_dcf_full_eval() {
467462
let x: [u8; 2] = ALPHAS[2][..2].try_into().unwrap();
468-
let prg = Aes256HirosePrg::<16, 2>::new(std::array::from_fn(|i| KEYS[i]));
463+
let prg = Aes256HirosePrg::<16, 2, 2>::new(std::array::from_fn(|i| KEYS[i]));
469464
let dcf = DcfImpl::<2, 16, _>::new(prg);
470465
let s0s: [[u8; 16]; 2] = thread_rng().gen();
471466
let f = CmpFn {
@@ -491,7 +486,7 @@ mod tests {
491486
#[test]
492487
fn test_dcf_full_eval_with_filter() {
493488
let x: [u8; 2] = ALPHAS[2][..2].try_into().unwrap();
494-
let prg = Aes256HirosePrg::<16, 2>::new(std::array::from_fn(|i| KEYS[i]));
489+
let prg = Aes256HirosePrg::<16, 2, 2>::new(std::array::from_fn(|i| KEYS[i]));
495490
let dcf = DcfImpl::<2, 16, _>::new_with_filter(prg, 15);
496491
let s0s: [[u8; 16]; 2] = thread_rng().gen();
497492
let f = CmpFn {

0 commit comments

Comments
 (0)