diff --git a/mpn.c b/mpn.c index 319f923..b8895c2 100644 --- a/mpn.c +++ b/mpn.c @@ -361,6 +361,49 @@ mpn_import (mp_ptr zp, mp_size_t *zsize, size_t count, int order, MPN_NORMALIZE (zp, *zsize); } +extern void __gmpn_powm(mp_limb_t *rp, const mp_limb_t *bp, mp_size_t bn, + const mp_limb_t *ep, mp_size_t en, const mp_limb_t *mp, + mp_size_t n, mp_limb_t *tp); +extern mp_size_t __gmpn_binvert_itch(mp_size_t n); +extern void __gmpn_powlo(mp_limb_t *rp, const mp_limb_t *bp, + const mp_limb_t *ep, mp_size_t en, mp_size_t n, + mp_limb_t *tp); +extern void __gmpn_binvert(mp_limb_t *rp, const mp_limb_t *up, mp_size_t n, + mp_limb_t *scratch); +extern void __gmpn_mullo_n(mp_limb_t *rp, const mp_limb_t *xp, + const mp_limb_t *yp, mp_size_t n); +/* Public wrappers: each one forwards its arguments unchanged to the matching __gmpn_* symbol declared extern above. */ +void mpn_powm(mp_limb_t *rp, const mp_limb_t *bp, mp_size_t bn, + const mp_limb_t *ep, mp_size_t en, const mp_limb_t *mp, + mp_size_t n, mp_limb_t *tp) +{ + __gmpn_powm(rp, bp, bn, ep, en, mp, n, tp); +} + +mp_size_t mpn_binvert_itch(mp_size_t n) +{ + return __gmpn_binvert_itch(n); +} + +void mpn_powlo(mp_limb_t *rp, const mp_limb_t *bp, + const mp_limb_t *ep, mp_size_t en, mp_size_t n, + mp_limb_t *tp) +{ + __gmpn_powlo(rp, bp, ep, en, n, tp); +} + +void mpn_binvert(mp_limb_t *rp, const mp_limb_t *up, mp_size_t n, + mp_limb_t *scratch) +{ + __gmpn_binvert(rp, up, n, scratch); +} + +void mpn_mullo_n(mp_limb_t *rp, const mp_limb_t *xp, + const mp_limb_t *yp, mp_size_t n) +{ + __gmpn_mullo_n(rp, xp, yp, n); +} + #ifdef __GNUC__ # pragma GCC diagnostic pop #endif diff --git a/mpn.h b/mpn.h index ae14df9..584cb3c 100644 --- a/mpn.h +++ b/mpn.h @@ -23,4 +23,25 @@ void * mpn_export(void *data, size_t *countp, int order, void mpn_import(mp_ptr zp, mp_size_t *zsize, size_t count, int order, size_t size, int endian, size_t nail, const void *data); +/* Compute r = b^e mod m. Requires that m is odd and e > 1. + Uses scratch space at tp of MAX(mpn_binvert_itch(n), 2n) limbs. 
*/ +void mpn_powm(mp_limb_t *rp, const mp_limb_t *bp, mp_size_t bn, + const mp_limb_t *ep, mp_size_t en, const mp_limb_t *mp, + mp_size_t n, mp_limb_t *tp); +mp_size_t mpn_binvert_itch(mp_size_t n); + +/* Compute r = b^e mod B^n, B is the limb base. + Requires normalized e. Uses scratch space of 3n words in tp. */ +void mpn_powlo(mp_limb_t *rp, const mp_limb_t *bp, + const mp_limb_t *ep, mp_size_t en, mp_size_t n, + mp_limb_t *tp); + +/* Compute r = u^(-1) mod B^n, B is the limb base. */ +void mpn_binvert(mp_limb_t *rp, const mp_limb_t *up, mp_size_t n, + mp_limb_t *scratch); + +/* Multiply two n-limb numbers and return the low n limbs of their product. */ +void mpn_mullo_n(mp_limb_t *rp, const mp_limb_t *xp, + const mp_limb_t *yp, mp_size_t n); + #endif /* MPN_H */ diff --git a/tests/t-pow.c b/tests/t-pow.c index bbaef26..9b64224 100644 --- a/tests/t-pow.c +++ b/tests/t-pow.c @@ -23,12 +23,16 @@ zz_ref_powm(const zz_t *u, const zz_t *v, const zz_t *w, zz_t *r) return ZZ_MEM; } mpz_init(z); + mpz_abs(mw, mw); mpz_powm(z, mu, mv, mw); if (zz_set_mpz_t(z, r)) { mpz_clear(z); return ZZ_MEM; } mpz_clear(z); + if (zz_isneg(w) && !zz_iszero(r) && !zz_isneg(r)) { + return zz_add(w, r, r); + } return ZZ_OK; } @@ -62,10 +66,10 @@ check_powm_bulk(void) if (zz_init(&u) || zz_random(bs, true, &u)) { abort(); } - if (zz_init(&v) || zz_random(32, true, &v)) { + if (zz_init(&v) || zz_random(128, true, &v)) { abort(); } - if (zz_init(&w) || zz_random(bs, false, &w)) { + if (zz_init(&w) || zz_random(bs, true, &w)) { abort(); } if (zz_init(&z)) { @@ -82,6 +86,12 @@ check_powm_bulk(void) zz_clear(&r); } else { + if (zz_iszero(&w)) { + if (zz_powm(&u, &v, &w, &z) != ZZ_VAL) { + abort(); + } + goto clear; + } if (zz_ref_gcd(&u, &w, &z) || zz_cmp(&z, 1) == ZZ_EQ) { abort(); } @@ -131,6 +141,7 @@ check_powm_bulk(void) abort(); } } +clear: zz_clear(&u); zz_clear(&v); zz_clear(&w); @@ -168,6 +179,31 @@ check_powm_examples(void) if (zz_set(0, &w) || zz_powm(&u, &v, &w, &w) != ZZ_VAL) { 
abort(); } + if (zz_set(123, &u) || zz_set(321, &v) || zz_set(1, &w) + || zz_powm(&u, &v, &w, &w) || zz_cmp(&w, 0)) + { + abort(); + } + if (zz_set(123, &u) || zz_set(0, &v) || zz_set(321, &w) + || zz_powm(&u, &v, &w, &w) || zz_cmp(&w, 1)) + { + abort(); + } + if (zz_set(0, &u) || zz_set(321, &v) || zz_set(123, &w) + || zz_powm(&u, &v, &w, &w) || zz_cmp(&w, 0)) + { + abort(); + } + if (zz_set(1, &u) || zz_set(321, &v) || zz_set(123, &w) + || zz_powm(&u, &v, &w, &w) || zz_cmp(&w, 1)) + { + abort(); + } + if (zz_set(321, &u) || zz_set(1, &v) || zz_set(123, &w) + || zz_powm(&u, &v, &w, &w) || zz_cmp(&w, 75)) + { + abort(); + } zz_clear(&u); zz_clear(&v); zz_clear(&w); diff --git a/zz.c b/zz.c index 1140fb6..c468044 100644 --- a/zz.c +++ b/zz.c @@ -2305,6 +2305,146 @@ zz_lcm(const zz_t *u, const zz_t *v, zz_t *w) return ret; } +static zz_err +_zz_powm(const zz_t *u, const zz_t *v, const zz_t *w, zz_t *res) +{ /* Computes res = u**v mod w and returns ZZ_OK, or ZZ_MEM if an allocation fails. The modulus is split into an odd part and 2**lsbpos: mpn_powm() gives the residue modulo the odd part, mpn_powlo() the residue modulo BASE**neven, and the two are merged through the inverse of the odd part. NOTE(review): zz_powm() in this patch appears to call this only with v > 0, w > 1 and u already reduced to a non-negative value -- confirm no other caller relies on weaker preconditions. */ + if (zz_resize(w->size, res)) { + return ZZ_MEM; /* LCOV_EXCL_LINE */ + } + SETNEG(false, res); + /* Handle (u**1 mod w) early, since mpn_pow* can't */ + if (zz_cmp_i64(v, 1) == ZZ_EQ) { + if (zz_div(u, w, NULL, res)) { + return ZZ_MEM; /* LCOV_EXCL_LINE */ + } + return ZZ_OK; + } + + zz_bitcnt_t lsbpos = zz_lsbpos(w); + zz_size_t n = w->size; + zz_t t1; + + zz_init(&t1); + if (lsbpos) { + if (zz_quo_2exp(w, lsbpos, &t1)) { + /* LCOV_EXCL_START */ + zz_clear(&t1); + return ZZ_MEM; + /* LCOV_EXCL_STOP */ + } + w = &t1; + } + + zz_size_t nodd = w->size; + zz_size_t neven = (zz_size_t)(lsbpos + ZZ_DIGIT_T_BITS - 1)/ZZ_DIGIT_T_BITS; + zz_size_t cnt = lsbpos % ZZ_DIGIT_T_BITS; + zz_size_t n_largest_binvert = MAX(neven, nodd); + zz_size_t itch_binvert = mpn_binvert_itch(n_largest_binvert); + zz_size_t itch = n + MAX(itch_binvert, 2*n); + + /* Now the original modulus equals w * 2**lsbpos; neven is the number of limbs needed for lsbpos bits */ + if (neven != 0) { + /* We will call both mpn_powm() and mpn_powlo() */ + itch += 2*n; + } + + zz_digit_t *volatile tp = malloc((size_t)itch * sizeof(zz_digit_t)); + zz_digit_t *volatile newup 
= NULL; + zz_digit_t *volatile newwp = NULL; + zz_digit_t *volatile rp = tp; + + if (!tp || TMP_OVERFLOW) { + /* LCOV_EXCL_START */ +clear: + free(rp); + free(newup); + free(newwp); + zz_clear(&t1); + zz_clear(res); + return ZZ_MEM; + /* LCOV_EXCL_STOP */ + } + tp += n; + /* Compute r = u**v mod w */ + mpn_powm (rp, u->digits, u->size, v->digits, v->size, + w->digits, nodd, tp); + if (neven != 0) { + zz_digit_t *r2, *xp, *yp, *odd_inv_2exp, *up, *wp; + + if (u->size < neven) { + /* Pad u with zeros. */ + newup = malloc((size_t)neven * sizeof(zz_digit_t)); + if (!newup) { + goto clear; /* LCOV_EXCL_LINE */ + } + mpn_copyi(newup, u->digits, u->size); + mpn_zero(newup + u->size, neven - u->size); + up = newup; + } + else { + up = u->digits; + } + r2 = tp; + if (up[0] % 2 == 0) { + if (v->size > 1) { + mpn_zero(r2, neven); + goto zero; + } + } + /* Compute r2 = u**v mod BASE**neven */ + mpn_powlo(r2, up, v->digits, v->size, neven, tp + neven); +zero: + free(newup); newup = NULL; /* reset so a later goto clear cannot free newup a second time */ + if (nodd < neven) { + /* Pad w with zeros */ + newwp = malloc((size_t)neven * sizeof(zz_digit_t)); + if (!newwp) { + goto clear; /* LCOV_EXCL_LINE */ + } + mpn_copyi(newwp, w->digits, nodd); + mpn_zero(newwp + nodd, neven - nodd); + wp = newwp; + zz_clear(&t1); /* NOTE(review): t1 is cleared again further down; assumes zz_clear() is safe to repeat -- confirm */ + } + else { + wp = w->digits; + } + odd_inv_2exp = tp + n; + /* odd_inv_2exp = w**(-1) mod BASE**neven */ + mpn_binvert(odd_inv_2exp, wp, neven, tp + 2*n); + /* r2 = r2 - r */ + mpn_sub(r2, r2, neven, rp, MIN(nodd, neven)); + xp = tp + 2*n; + /* x = (odd_inv_2exp * r2) mod BASE**neven */ + mpn_mullo_n(xp, odd_inv_2exp, r2, neven); + if (cnt) { /* lsbpos is not a whole number of limbs: keep only the low cnt bits of the top limb of x */ + xp[neven - 1] &= ((mp_limb_t)1 << cnt) - 1; + } + yp = tp; + if (neven > nodd) { + mpn_mul(yp, xp, neven, wp, nodd); + } + else { + mpn_mul(yp, wp, nodd, xp, neven); + } + free(newwp); + /* r += x * w */ + mpn_add(rp, yp, n, rp, nodd); + } + zz_clear(&t1); + if (zz_resize(n, res)) { + /* LCOV_EXCL_START */ + free(rp); + zz_clear(res); + return ZZ_MEM; + /* LCOV_EXCL_STOP */ + } + mpn_copyi(res->digits, 
rp, n); + free(rp); + zz_normalize(res); + return ZZ_OK; +} + zz_err zz_powm(const zz_t *u, const zz_t *v, const zz_t *w, zz_t *res) { @@ -2357,56 +2497,83 @@ zz_powm(const zz_t *u, const zz_t *v, const zz_t *w, zz_t *res) return ret; } - zz_t o1, o2; + int negativeOutput = 0; + zz_t o1, o2, o3; + zz_err ret = ZZ_OK; - if (zz_init(&o1) || zz_init(&o2)) { + if (zz_init(&o1) || zz_init(&o2) || zz_init(&o3)) { /* LCOV_EXCL_START */ -mem: zz_clear(&o1); zz_clear(&o2); + zz_clear(&o3); return ZZ_MEM; /* LCOV_EXCL_STOP */ } + if (ISNEG(w)) { + if (zz_pos(w, &o3)) { + goto end; /* LCOV_EXCL_LINE */ + } + negativeOutput = 1; + SETNEG(false, &o3); + w = &o3; + } if (ISNEG(v)) { - zz_err ret = zz_inverse(u, w, &o2); + if (zz_pos(v, &o2)) { + goto end; /* LCOV_EXCL_LINE */ + } + SETNEG(false, &o2); + v = &o2; + if ((ret = zz_inverse(u, w, &o1)) == ZZ_MEM) { + goto end; /* LCOV_EXCL_LINE */ + } if (ret == ZZ_VAL) { +end: /* NOTE(review): every goto end, including the ret == ZZ_MEM case above and the _zz_powm() failure below, lands on return ZZ_VAL; the removed mem: label returned ZZ_MEM */ zz_clear(&o1); zz_clear(&o2); + zz_clear(&o3); return ZZ_VAL; } - if (ret == ZZ_MEM || zz_abs(v, &o1)) { - goto mem; /* LCOV_EXCL_LINE */ + u = &o1; + } + if (ISNEG(u) || u->size > w->size) { + zz_t tmp; + + if (zz_init(&tmp) || zz_div(u, w, NULL, &tmp) + || zz_pos(&tmp, &o1)) + { + /* LCOV_EXCL_START */ + zz_clear(&tmp); + goto end; + /* LCOV_EXCL_STOP */ } - u = &o2; - v = &o1; + zz_clear(&tmp); + u = &o1; } - if (u->size > INT_MAX || v->size > INT_MAX || w->size > INT_MAX) { - return ZZ_MEM; /* LCOV_EXCL_LINE */ + if (zz_cmp_i64(w, 1) == ZZ_EQ) { + ret = zz_set_i64(0, res); } - - mpz_t z; - TMP_MPZ(b, u) - TMP_MPZ(e, v) - TMP_MPZ(m, w) - if (TMP_OVERFLOW) { - return ZZ_MEM; /* LCOV_EXCL_LINE */ + else if (!v->size) { + ret = zz_set_i64(1, res); } - mpz_init(z); - mpz_powm(z, b, e, m); - if (zz_set_mpz_t(z, res)) { - /* LCOV_EXCL_START */ - mpz_clear(z); - goto mem; - /* LCOV_EXCL_STOP */ + else if (!u->size) { + ret = zz_set_i64(0, res); } - mpz_clear(z); - if (ISNEG(w) && res->size && zz_add(w, res, res)) { - goto mem; /* LCOV_EXCL_LINE */ + else if 
(zz_cmp_i64(u, 1) == ZZ_EQ) { + ret = zz_set_i64(1, res); + } + else { + if (_zz_powm(u, v, w, res)) { + goto end; /* LCOV_EXCL_LINE */ + } + } + if (negativeOutput && !ret && res->size && zz_sub(res, w, res)) { + goto end; /* LCOV_EXCL_LINE */ } zz_clear(&o1); zz_clear(&o2); - return ZZ_OK; + zz_clear(&o3); + return ret; } zz_err