Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs in ECAdd bloq #1489

Merged
merged 18 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 102 additions & 43 deletions qualtran/bloqs/cryptography/ecc/ec_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
bloq_example,
BloqBuilder,
BloqDocSpec,
CtrlSpec,
DecomposeTypeError,
QBit,
QMontgomeryUInt,
Expand All @@ -32,8 +33,8 @@
Soquet,
SoquetT,
)
from qualtran.bloqs.arithmetic.comparison import Equals
from qualtran.bloqs.basic_gates import CNOT, IntState, Toffoli, ZeroState
from qualtran.bloqs.arithmetic import Equals, Xor
from qualtran.bloqs.basic_gates import CNOT, IntState, Toffoli, XGate, ZeroState
from qualtran.bloqs.bookkeeping import Free
from qualtran.bloqs.mcmt import MultiAnd, MultiControlX, MultiTargetCNOT
from qualtran.bloqs.mod_arithmetic import (
Expand Down Expand Up @@ -192,6 +193,12 @@ def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
class _ECAddStepTwo(Bloq):
r"""Performs step two of the ECAdd bloq.

Includes a bugfix for the scenario where the calculated λ ( = (y - b) / (x - a)) is equivalent
to λ_r ( = 3 * a ^ 2 + c_1 / (2 * b)) and f_1 is wrongfully cleared. We accomplish this by
introducing a new ancilla qubit set by an equals operation on the computed λ and the classical,
pre-computed λ_r. We then control the equals bloq on this ancilla qubit which will only clear
the f_1 flag in the correct situation. Finally, we clear and free the ancilla afterwards.

Args:
n: The bitsize of the two registers storing the elliptic curve point
mod: The modulus of the field in which we do the addition.
Expand Down Expand Up @@ -253,10 +260,6 @@ def on_classical_vals(
lam = QMontgomeryUInt(self.n, self.mod).montgomery_product(
int(y), QMontgomeryUInt(self.n, self.mod).montgomery_inverse(int(x))
)
# TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit
# which flips f1 when lam and lam_r are equal.
if lam == lam_r:
f1 = (f1 + 1) % 2
else:
lam = 0
return {'f1': f1, 'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam, 'lam_r': lam_r}
Expand Down Expand Up @@ -296,6 +299,12 @@ def build_composite_bloq(
y=y,
)

# Allocate an ancilla qubit that acts as a flag for the rare condition that the
# pre-computed lambda_r is equal to the calculated lambda. This ancilla is used to properly
# clear the f1 qubit when lambda is set to lambda_r.
ancilla = bb.allocate()
z4, lam_r, ancilla = bb.add(Equals(QMontgomeryUInt(self.n)), x=z4, y=lam_r, target=ancilla)

# If ctrl = 1 and x != a: lam = (y - b) / (x - a) % p.
z4_split = bb.split(z4)
lam_split = bb.split(lam)
Expand Down Expand Up @@ -323,7 +332,18 @@ def build_composite_bloq(
lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n))

# If lam = lam_r: return f1 = 0. (If not we will flip f1 to 0 at the end iff x_r = y_r = 0).
lam, lam_r, f1 = bb.add(Equals(QMontgomeryUInt(self.n)), x=lam, y=lam_r, target=f1)
# Only flip when lam is set to lam_r.
ancilla, lam, lam_r, f1 = bb.add(
Equals(QMontgomeryUInt(self.n)).controlled(ctrl_spec=CtrlSpec(cvs=0)),
ctrl=ancilla,
x=lam,
y=lam_r,
target=f1,
)

# Clear the ancilla bit and free it.
z4, lam_r, ancilla = bb.add(Equals(QMontgomeryUInt(self.n)), x=z4, y=lam_r, target=ancilla)
bb.free(ancilla)

# Uncompute the modular multiplication then the modular inversion.
x, y = bb.add(
Expand All @@ -343,7 +363,8 @@ def build_composite_bloq(

def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
return {
Equals(QMontgomeryUInt(self.n)): 1,
Equals(QMontgomeryUInt(self.n)): 2,
Equals(QMontgomeryUInt(self.n)).controlled(ctrl_spec=CtrlSpec(cvs=0)): 1,
ModSub(QMontgomeryUInt(self.n), mod=self.mod): 1,
CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1,
KaliskiModInverse(bitsize=self.n, mod=self.mod): 1,
Expand Down Expand Up @@ -639,6 +660,13 @@ def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
class _ECAddStepFive(Bloq):
r"""Performs step five of the ECAdd bloq.

Includes a bugfix for the scenario where (a, b) = (x, y) and a - x_r = 0. In this situation,
f_1 is set and f_2 - f_4 is cleared (which means that the ctrl qubit is set). Because a - x_r
is 0, the computed λ is undefined (and with this construction the computed λ will be set to 0),
however the λ is non-zero and should be cleared with λ_r. We accomplish this with a controled
Xor bloq controlled on the ctrl qubit and the condition that the x register (a - x_r) = 0. In
this ase we clear the λ register with λ_r.

Args:
n: The bitsize of the two registers storing the elliptic curve point
mod: The modulus of the field in which we do the addition.
Expand All @@ -652,6 +680,7 @@ class _ECAddStepFive(Bloq):
will contain the x component of the resultant curve point.
y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which
will contain the y component of the resultant curve point.
lam_r: The precomputed lambda slope used in the addition operation if (a, b) = (x, y) in montgomery form.
lam: The lambda slope used in the addition operation.

References:
Expand All @@ -672,6 +701,7 @@ def signature(self) -> 'Signature':
Register('b', QMontgomeryUInt(self.n)),
Register('x', QMontgomeryUInt(self.n)),
Register('y', QMontgomeryUInt(self.n)),
Register('lam_r', QMontgomeryUInt(self.n)),
Register('lam', QMontgomeryUInt(self.n), side=Side.LEFT),
]
)
Expand All @@ -683,14 +713,15 @@ def on_classical_vals(
b: 'ClassicalValT',
x: 'ClassicalValT',
y: 'ClassicalValT',
lam_r: 'ClassicalValT',
lam: 'ClassicalValT',
) -> Dict[str, 'ClassicalValT']:
if ctrl == 1:
x = (a - x) % self.mod
y = (y - b) % self.mod
else:
x = (x + a) % self.mod
return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y}
return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam_r': lam_r}

def build_composite_bloq(
self,
Expand All @@ -700,6 +731,7 @@ def build_composite_bloq(
b: Soquet,
x: Soquet,
y: Soquet,
lam_r: Soquet,
lam: Soquet,
) -> Dict[str, 'SoquetT']:
if is_symbolic(self.n):
Expand Down Expand Up @@ -729,9 +761,15 @@ def build_composite_bloq(
z4_split[i] = ctrls[1]
z4 = bb.join(z4_split, dtype=QMontgomeryUInt(self.n))
lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n))
# TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit where lambda
# is not set to 0 before being freed.
bb.add(Free(QMontgomeryUInt(self.n), dirty=True), reg=lam)

# If the denominator of lambda is 0, lam = lam_r so we clear lam with lam_r.
clear_lam = (
Xor(QMontgomeryUInt(self.n))
.controlled(CtrlSpec(qdtypes=QMontgomeryUInt(self.n), cvs=0))
.controlled()
)
ctrl, x, lam_r, lam = bb.add(clear_lam, ctrl1=ctrl, ctrl2=x, x=lam_r, y=lam)
bb.add(Free(QMontgomeryUInt(self.n)), reg=lam)

# Uncompute multiplication and inverse.
x, y = bb.add(
Expand All @@ -756,9 +794,14 @@ def build_composite_bloq(
ctrl, b, y = bb.add(CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=b, y=y)

# Return the output registers.
return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y}
return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam_r': lam_r}

def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
clear_lam = (
Xor(QMontgomeryUInt(self.n))
.controlled(CtrlSpec(qdtypes=QMontgomeryUInt(self.n), cvs=0))
.controlled()
)
return {
CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1,
KaliskiModInverse(bitsize=self.n, mod=self.mod): 1,
Expand All @@ -771,6 +814,7 @@ def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
KaliskiModInverse(bitsize=self.n, mod=self.mod).adjoint(): 1,
ModAdd(self.n, mod=self.mod): 1,
MultiControlX(cvs=[1, 1]): self.n,
clear_lam: 1,
CModNeg(QMontgomeryUInt(self.n), mod=self.mod): 1,
}

Expand All @@ -779,6 +823,16 @@ def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
class _ECAddStepSix(Bloq):
r"""Performs step six of the ECAdd bloq.

Include bugfixes for the following scenarios:
1. f_2 is improperly cleared when ((x, y) = (0, 0) AND b = 0) OR ((a, b) = (0, 0) AND
y = 0).
2. f_4 is improperly cleared when P_1 = P_2 AND f_4 is set.

The bugs are fixed respectively by:
1. Clearing f_2 when x = y = b = 0 OR a = b = y = 0 using an XGate controlled on those
registers.
2. Moving the CModSub and CModAdd bloqs before the Equals bloq.

Args:
n: The bitsize of the two registers storing the elliptic curve point
mod: The modulus of the field in which we do the addition.
Expand Down Expand Up @@ -863,6 +917,11 @@ def build_composite_bloq(
f3 = f_ctrls[1]
f4 = f_ctrls[2]

# Unset f2 if ((a, b) = (0, 0) AND y = 0) OR ((x, y) = (0, 0) AND b = 0).
mcx = XGate().controlled(CtrlSpec(qdtypes=QMontgomeryUInt(self.n), cvs=[0, 0, 0]))
[a, b, y], f2 = bb.add(mcx, ctrl=[a, b, y], q=f2)
[x, y, b], f2 = bb.add(mcx, ctrl=[x, y, b], q=f2)

# Set (x, y) to (a, b) if f4 is set.
a_split = bb.split(a)
x_split = bb.split(x)
Expand All @@ -883,24 +942,6 @@ def build_composite_bloq(
b = bb.join(b_split, QMontgomeryUInt(self.n))
y = bb.join(y_split, QMontgomeryUInt(self.n))

# Unset f4 if (x, y) = (a, b).
ab = bb.join(np.concatenate([bb.split(a), bb.split(b)]), dtype=QMontgomeryUInt(2 * self.n))
xy = bb.join(np.concatenate([bb.split(x), bb.split(y)]), dtype=QMontgomeryUInt(2 * self.n))
ab, xy, f4 = bb.add(Equals(QMontgomeryUInt(2 * self.n)), x=ab, y=xy, target=f4)
ab_split = bb.split(ab)
a = bb.join(ab_split[: int(self.n)], dtype=QMontgomeryUInt(self.n))
b = bb.join(ab_split[int(self.n) :], dtype=QMontgomeryUInt(self.n))
xy_split = bb.split(xy)
x = bb.join(xy_split[: int(self.n)], dtype=QMontgomeryUInt(self.n))
y = bb.join(xy_split[int(self.n) :], dtype=QMontgomeryUInt(self.n))

# Unset f3 if (a, b) = (0, 0).
ab_arr = np.concatenate([bb.split(a), bb.split(b)])
ab_arr, f3 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=ab_arr, target=f3)
ab_arr = np.split(ab_arr, 2)
a = bb.join(ab_arr[0], dtype=QMontgomeryUInt(self.n))
b = bb.join(ab_arr[1], dtype=QMontgomeryUInt(self.n))

# If f1 and f2 are set, subtract a from x and add b to y.
ancilla = bb.add(ZeroState())
toff_ctrl = [f1, f2]
Expand All @@ -923,6 +964,24 @@ def build_composite_bloq(
f2 = toff_ctrl[1]
bb.add(Free(QBit()), reg=ancilla)

# Unset f4 if (x, y) = (a, b).
ab = bb.join(np.concatenate([bb.split(a), bb.split(b)]), dtype=QMontgomeryUInt(2 * self.n))
xy = bb.join(np.concatenate([bb.split(x), bb.split(y)]), dtype=QMontgomeryUInt(2 * self.n))
ab, xy, f4 = bb.add(Equals(QMontgomeryUInt(2 * self.n)), x=ab, y=xy, target=f4)
ab_split = bb.split(ab)
a = bb.join(ab_split[: int(self.n)], dtype=QMontgomeryUInt(self.n))
b = bb.join(ab_split[int(self.n) :], dtype=QMontgomeryUInt(self.n))
xy_split = bb.split(xy)
x = bb.join(xy_split[: int(self.n)], dtype=QMontgomeryUInt(self.n))
y = bb.join(xy_split[int(self.n) :], dtype=QMontgomeryUInt(self.n))

# Unset f3 if (a, b) = (0, 0).
ab_arr = np.concatenate([bb.split(a), bb.split(b)])
ab_arr, f3 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=ab_arr, target=f3)
ab_arr = np.split(ab_arr, 2)
a = bb.join(ab_arr[0], dtype=QMontgomeryUInt(self.n))
b = bb.join(ab_arr[1], dtype=QMontgomeryUInt(self.n))

# Unset f1 and f2 if (x, y) = (0, 0).
xy_arr = np.concatenate([bb.split(x), bb.split(y)])
xy_arr, junk, out = bb.add(MultiAnd(cvs=[0] * 2 * self.n), ctrl=xy_arr)
Expand All @@ -939,33 +998,32 @@ def build_composite_bloq(
y = bb.join(xy_arr[1], dtype=QMontgomeryUInt(self.n))

# Free all ancilla qubits in the zero state.
# TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bugs in circuit where f1,
# f2, and f4 are freed before being set to 0.
bb.add(Free(QBit(), dirty=True), reg=f1)
bb.add(Free(QBit(), dirty=True), reg=f2)
bb.add(Free(QBit()), reg=f1)
bb.add(Free(QBit()), reg=f2)
bb.add(Free(QBit()), reg=f3)
bb.add(Free(QBit(), dirty=True), reg=f4)
bb.add(Free(QBit()), reg=f4)
bb.add(Free(QBit()), reg=ctrl)

# Return the output registers.
return {'a': a, 'b': b, 'x': x, 'y': y}

def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
cvs: Union[list[int], HasLength]
cvs2: Union[list[int], HasLength]
if isinstance(self.n, int):
cvs = [0] * 2 * self.n
cvs2 = [0] * 2 * self.n
else:
cvs = HasLength(2 * self.n)
cvs2 = HasLength(2 * self.n)
return {
MultiControlX(cvs=cvs): 1,
MultiControlX(cvs=cvs2): 1,
XGate().controlled(CtrlSpec(qdtypes=QMontgomeryUInt(self.n), cvs=[0, 0, 0])): 2,
MultiControlX(cvs=[0] * 3): 1,
CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1,
CModAdd(QMontgomeryUInt(self.n), mod=self.mod): 1,
Toffoli(): 2 * self.n + 4,
Equals(QMontgomeryUInt(2 * self.n)): 1,
MultiAnd(cvs=cvs): 1,
MultiAnd(cvs=cvs2): 1,
MultiTargetCNOT(2): 1,
MultiAnd(cvs=cvs).adjoint(): 1,
MultiAnd(cvs=cvs2).adjoint(): 1,
}


Expand Down Expand Up @@ -1044,13 +1102,14 @@ def build_composite_bloq(
x, y, lam = bb.add(
_ECAddStepFour(n=self.n, mod=self.mod, window_size=self.window_size), x=x, y=y, lam=lam
)
ctrl, a, b, x, y = bb.add(
ctrl, a, b, x, y, lam_r = bb.add(
_ECAddStepFive(n=self.n, mod=self.mod, window_size=self.window_size),
ctrl=ctrl,
a=a,
b=b,
x=x,
y=y,
lam_r=lam_r,
lam=lam,
)
a, b, x, y = bb.add(
Expand Down
15 changes: 11 additions & 4 deletions qualtran/bloqs/cryptography/ecc/ec_add_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def test_ec_add_steps_classical_fast(n, m, a, b, x, y):
b=step_3['b'],
x=step_4['x'],
y=step_4['y'],
lam_r=step_2['lam_r'],
lam=step_4['lam'],
)
ret2 = bloq.decompose_bloq().call_classically(
Expand All @@ -119,6 +120,7 @@ def test_ec_add_steps_classical_fast(n, m, a, b, x, y):
b=step_3['b'],
x=step_4['x'],
y=step_4['y'],
lam_r=step_2['lam_r'],
lam=step_4['lam'],
)
assert ret1 == ret2
Expand All @@ -129,6 +131,7 @@ def test_ec_add_steps_classical_fast(n, m, a, b, x, y):
b=step_3['b'],
x=step_4['x'],
y=step_4['y'],
lam_r=step_2['lam_r'],
lam=step_4['lam'],
)
bloq = _ECAddStepSix(n=n, mod=p)
Expand Down Expand Up @@ -252,6 +255,7 @@ def test_ec_add_steps_classical(n, m, a, b, x, y):
b=step_3['b'],
x=step_4['x'],
y=step_4['y'],
lam_r=step_2['lam_r'],
lam=step_4['lam'],
)
ret2 = bloq.decompose_bloq().call_classically(
Expand All @@ -260,6 +264,7 @@ def test_ec_add_steps_classical(n, m, a, b, x, y):
b=step_3['b'],
x=step_4['x'],
y=step_4['y'],
lam_r=step_2['lam_r'],
lam=step_4['lam'],
)
assert ret1 == ret2
Expand All @@ -270,6 +275,7 @@ def test_ec_add_steps_classical(n, m, a, b, x, y):
b=step_3['b'],
x=step_4['x'],
y=step_4['y'],
lam_r=step_2['lam_r'],
lam=step_4['lam'],
)
bloq = _ECAddStepSix(n=n, mod=p)
Expand Down Expand Up @@ -417,12 +423,13 @@ def test_ec_add_symbolic_cost():

# Litinski 2023 https://arxiv.org/abs/2306.08585
# Based on the counts from Figures 3, 5, and 8 the toffoli count for ECAdd is 126.5n^2 + 189n.
# The following formula is 126.5n^2 + 195.5n - 31. We account for the discrepancy in the
# The following formula is 126.5n^2 + 215.5n - 34. We account for the discrepancy in the
# coefficient of n by a reduction in the toffoli cost of Montgomery ModMult, an increase in the
# toffoli cost for Kaliski Mod Inverse, n extra toffolis in ModNeg, 2n extra toffolis to do n
# 3-controlled toffolis in step 2. The expression is written with rationals because sympy
# comparison fails with floats.
assert total_toff == sympy.Rational(253, 2) * n**2 + sympy.Rational(407, 2) * n - 31
# 3-controlled toffolis in step 2, and a few extra gates added to fix bugs found in the circuit
# (see class docstrings). The expression is written with rationals because sympy comparison
# fails with floats.
assert total_toff == sympy.Rational(253, 2) * n**2 + sympy.Rational(431, 2) * n - 34


def test_ec_add(bloq_autotester):
Expand Down