Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add argument for repeating transformations to optimize_for_target_gateset #6426

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 39 additions & 18 deletions cirq-core/cirq/transformers/optimize_for_target_gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Transformers to rewrite a circuit using gates from a given target gateset."""

from typing import Optional, Callable, Hashable, Sequence, TYPE_CHECKING
from typing import Optional, Callable, Hashable, Sequence, TYPE_CHECKING, Union

from cirq import circuits
from cirq.protocols import decompose_protocol as dp
Expand Down Expand Up @@ -102,20 +102,29 @@ def optimize_for_target_gateset(
context: Optional['cirq.TransformerContext'] = None,
gateset: Optional['cirq.CompilationTargetGateset'] = None,
ignore_failures: bool = True,
max_num_passes: Union[int, None] = 1,
) -> 'cirq.Circuit':
"""Transforms the given circuit into an equivalent circuit using gates accepted by `gateset`.

Repeat max_num_passes times or when `max_num_passes=None` until no further changes can be done
1. Run all `gateset.preprocess_transformers`
2. Convert operations using built-in cirq decompose + `gateset.decompose_to_target_gateset`.
3. Run all `gateset.postprocess_transformers`

Note:
The optimizer is a heuristic and may not produce optimal results even with
max_num_passes=None. The prerprocessors and postprocessors of the gate set
as well as their order yield different results.


Args:
circuit: Input circuit to transform. It will not be modified.
context: `cirq.TransformerContext` storing common configurable options for transformers.
gateset: Target gateset, which should be an instance of `cirq.CompilationTargetGateset`.
ignore_failures: If set, operations that fail to convert are left unchanged. If not set,
conversion failures raise a ValueError.

max_num_passes: The maximum number of passes to do. A value of `None` means to keep
iterating until no further improvements can be made.
Returns:
An equivalent circuit containing gates accepted by `gateset`.

Expand All @@ -126,20 +135,32 @@ def optimize_for_target_gateset(
return _decompose_operations_to_target_gateset(
circuit, context=context, ignore_failures=ignore_failures
)

for transformer in gateset.preprocess_transformers:
circuit = transformer(circuit, context=context)

circuit = _decompose_operations_to_target_gateset(
circuit,
context=context,
gateset=gateset,
decomposer=gateset.decompose_to_target_gateset,
ignore_failures=ignore_failures,
tags_to_decompose=(gateset._intermediate_result_tag,),
)

for transformer in gateset.postprocess_transformers:
circuit = transformer(circuit, context=context)

if isinstance(max_num_passes, int):
_outerloop = lambda: range(max_num_passes)
else:

def _outerloop():
while True:
yield 0

initial_num_moments, initial_num_ops = len(circuit), len(tuple(circuit.all_operations()))
for _ in _outerloop():
for transformer in gateset.preprocess_transformers:
circuit = transformer(circuit, context=context)
circuit = _decompose_operations_to_target_gateset(
circuit,
context=context,
gateset=gateset,
decomposer=gateset.decompose_to_target_gateset,
ignore_failures=ignore_failures,
tags_to_decompose=(gateset._intermediate_result_tag,),
)
for transformer in gateset.postprocess_transformers:
circuit = transformer(circuit, context=context)

num_moments, num_ops = len(circuit), len(tuple(circuit.all_operations()))
if (num_moments, num_ops) == (initial_num_moments, initial_num_ops):
# Stop early. No further optimizations can be done.
break
initial_num_moments, initial_num_ops = num_moments, num_ops
return circuit.unfreeze(copy=False)
81 changes: 81 additions & 0 deletions cirq-core/cirq/transformers/optimize_for_target_gateset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import cirq
from cirq.protocols.decompose_protocol import DecomposeResult
from cirq.transformers.optimize_for_target_gateset import _decompose_operations_to_target_gateset
Expand Down Expand Up @@ -243,3 +245,82 @@ def test_optimize_for_target_gateset_deep():
1: ───#2───────────────────────────────────────────────────────────────────────────
''',
)


@pytest.mark.parametrize('max_num_passes', [2, None])
def test_optimize_for_target_gateset_multiple_passes(max_num_passes: Union[int, None]):
gateset = cirq.CZTargetGateset()

input_circuit = cirq.Circuit(
[
cirq.Moment(
cirq.X(cirq.LineQubit(1)),
cirq.X(cirq.LineQubit(2)),
cirq.X(cirq.LineQubit(3)),
cirq.X(cirq.LineQubit(6)),
),
cirq.Moment(
cirq.H(cirq.LineQubit(0)),
cirq.H(cirq.LineQubit(1)),
cirq.H(cirq.LineQubit(2)),
cirq.H(cirq.LineQubit(3)),
cirq.H(cirq.LineQubit(4)),
cirq.H(cirq.LineQubit(5)),
cirq.H(cirq.LineQubit(6)),
),
cirq.Moment(
cirq.H(cirq.LineQubit(1)), cirq.H(cirq.LineQubit(3)), cirq.H(cirq.LineQubit(5))
),
cirq.Moment(
cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(1)),
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(3)),
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(5)),
),
cirq.Moment(
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(1)),
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(3)),
cirq.CZ(cirq.LineQubit(6), cirq.LineQubit(5)),
),
]
)

desired_circuit = cirq.Circuit.from_moments(
cirq.Moment(
cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=-0.5, z_exponent=1.0).on(
cirq.LineQubit(4)
)
),
cirq.Moment(cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(5))),
cirq.Moment(
cirq.PhasedXZGate(axis_phase_exponent=-1.0, x_exponent=1, z_exponent=0).on(
cirq.LineQubit(1)
),
cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=-0.5, z_exponent=1.0).on(
cirq.LineQubit(0)
),
cirq.PhasedXZGate(axis_phase_exponent=-1.0, x_exponent=1, z_exponent=0).on(
cirq.LineQubit(3)
),
cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0.0).on(
cirq.LineQubit(2)
),
),
cirq.Moment(
cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(1)),
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(3)),
),
cirq.Moment(
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(1)),
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(3)),
),
cirq.Moment(
cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0.0).on(
cirq.LineQubit(6)
)
),
cirq.Moment(cirq.CZ(cirq.LineQubit(6), cirq.LineQubit(5))),
)
got = cirq.optimize_for_target_gateset(
input_circuit, gateset=gateset, max_num_passes=max_num_passes
)
assert got == desired_circuit
Loading