Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add argument for repeating transformations to optimize_for_target_gateset #6426

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 39 additions & 15 deletions cirq-core/cirq/transformers/optimize_for_target_gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Transformers to rewrite a circuit using gates from a given target gateset."""

from typing import Optional, Callable, Hashable, Sequence, TYPE_CHECKING
from typing import Optional, Callable, Hashable, Sequence, TYPE_CHECKING, Union

from cirq import circuits
from cirq.protocols import decompose_protocol as dp
Expand Down Expand Up @@ -102,20 +102,26 @@ def optimize_for_target_gateset(
context: Optional['cirq.TransformerContext'] = None,
gateset: Optional['cirq.CompilationTargetGateset'] = None,
ignore_failures: bool = True,
max_num_passes: Union[int, None] = 1,
) -> 'cirq.Circuit':
"""Transforms the given circuit into an equivalent circuit using gates accepted by `gateset`.

Repeat max_num_passes times or when `max_num_passes=None` until no further changes can be done
1. Run all `gateset.preprocess_transformers`
2. Convert operations using built-in cirq decompose + `gateset.decompose_to_target_gateset`.
3. Run all `gateset.postprocess_transformers`

Note:
The optimizer is heuristic and may not produce optimal results even with max_num_passes=None.

Args:
circuit: Input circuit to transform. It will not be modified.
context: `cirq.TransformerContext` storing common configurable options for transformers.
gateset: Target gateset, which should be an instance of `cirq.CompilationTargetGateset`.
ignore_failures: If set, operations that fail to convert are left unchanged. If not set,
conversion failures raise a ValueError.

max_num_passes: The maximum number of passes to do. A value of `None` means keep iterating
until no further improvements can be done.
Returns:
An equivalent circuit containing gates accepted by `gateset`.

Expand All @@ -126,20 +132,38 @@ def optimize_for_target_gateset(
return _decompose_operations_to_target_gateset(
circuit, context=context, ignore_failures=ignore_failures
)
if isinstance(max_num_passes, int):
_outerloop = lambda: range(max_num_passes)
else:

def _outerloop():
while True:
yield 0

initial_num_moments, initial_num_ops = len(circuit), len(tuple(circuit.all_operations()))
for _ in _outerloop():
for transformer in gateset.preprocess_transformers:
circuit = transformer(circuit, context=context)

circuit = _decompose_operations_to_target_gateset(
circuit,
context=context,
gateset=gateset,
decomposer=gateset.decompose_to_target_gateset,
ignore_failures=ignore_failures,
tags_to_decompose=(gateset._intermediate_result_tag,),
)

for transformer in gateset.preprocess_transformers:
circuit = transformer(circuit, context=context)

circuit = _decompose_operations_to_target_gateset(
circuit,
context=context,
gateset=gateset,
decomposer=gateset.decompose_to_target_gateset,
ignore_failures=ignore_failures,
tags_to_decompose=(gateset._intermediate_result_tag,),
)
for transformer in gateset.postprocess_transformers:
circuit = transformer(circuit, context=context)

for transformer in gateset.postprocess_transformers:
circuit = transformer(circuit, context=context)
circuit = circuits.Circuit(
op for op in circuit.all_operations()
) # Ensure the circuit is contracted.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This leads to a loss of moment structure. When we wrote this function originally, preserving moment structure was an important requirement. Are you sure we want to do this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what was the original requirement exactly? the circuit in the issue is supposed to endup with 3 moments as per @eliottrosenberg which can't happen without some sort of contraction. probably there is a way to do the contraction without destroying the moment structure. but we need to clarify what it's exactly that we want to preserve or if we simply need another preserve_moment_structure parameter

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think reducing the number of moments is generally a good thing. In general, this can't preserve the number of moments (i.e. if I ask it to transform a cirq.ISWAP gate into CZ).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tanujkhattar apart from this line. the current implementation doesn't preserve the moment structure, for example look at the output in the issue. CZ(4, 5) is earlier than it should and PhXZ(a=0,x=1,z=0)(6)───PhXZ(a=0.5,x=-0.5,z=1)(6)-CZ(5, 6) are a lot later than they should.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's continue the discussion on the original issue here - #6422 (comment)


num_moments, num_ops = len(circuit), len(tuple(circuit.all_operations()))
if (num_moments, num_ops) == (initial_num_moments, initial_num_ops):
# Stop early. No further optimizations can be done.
break
initial_num_moments, initial_num_ops = num_moments, num_ops
return circuit.unfreeze(copy=False)
72 changes: 72 additions & 0 deletions cirq-core/cirq/transformers/optimize_for_target_gateset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import cirq
from cirq.protocols.decompose_protocol import DecomposeResult
from cirq.transformers.optimize_for_target_gateset import _decompose_operations_to_target_gateset
Expand Down Expand Up @@ -243,3 +245,73 @@ def test_optimize_for_target_gateset_deep():
1: ───#2───────────────────────────────────────────────────────────────────────────
''',
)


@pytest.mark.parametrize('max_num_passes', [2, None])
def test_optimize_for_target_gateset_multiple_passes(max_num_passes: Union[int, None]):
gateset = cirq.CZTargetGateset()

input_circuit = cirq.Circuit(
[
cirq.Moment(
cirq.X(cirq.LineQubit(1)),
cirq.X(cirq.LineQubit(2)),
cirq.X(cirq.LineQubit(3)),
cirq.X(cirq.LineQubit(6)),
),
cirq.Moment(
cirq.H(cirq.LineQubit(0)),
cirq.H(cirq.LineQubit(1)),
cirq.H(cirq.LineQubit(2)),
cirq.H(cirq.LineQubit(3)),
cirq.H(cirq.LineQubit(4)),
cirq.H(cirq.LineQubit(5)),
cirq.H(cirq.LineQubit(6)),
),
cirq.Moment(
cirq.H(cirq.LineQubit(1)), cirq.H(cirq.LineQubit(3)), cirq.H(cirq.LineQubit(5))
),
cirq.Moment(
cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(1)),
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(3)),
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(5)),
),
cirq.Moment(
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(1)),
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(3)),
cirq.CZ(cirq.LineQubit(6), cirq.LineQubit(5)),
),
]
)

desired_circuit = cirq.Circuit(
cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=-0.5, z_exponent=1.0).on(
cirq.LineQubit(4)
),
cirq.PhasedXZGate(axis_phase_exponent=-1.0, x_exponent=1, z_exponent=0).on(
cirq.LineQubit(1)
),
cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0.0).on(
cirq.LineQubit(2)
),
cirq.PhasedXZGate(axis_phase_exponent=0.5, x_exponent=-0.5, z_exponent=1.0).on(
cirq.LineQubit(0)
),
cirq.PhasedXZGate(axis_phase_exponent=-1.0, x_exponent=1, z_exponent=0).on(
cirq.LineQubit(3)
),
cirq.PhasedXZGate(axis_phase_exponent=-0.5, x_exponent=0.5, z_exponent=0.0).on(
cirq.LineQubit(6)
),
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(5)),
cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(1)),
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(3)),
cirq.CZ(cirq.LineQubit(2), cirq.LineQubit(1)),
cirq.CZ(cirq.LineQubit(4), cirq.LineQubit(3)),
cirq.CZ(cirq.LineQubit(6), cirq.LineQubit(5)),
)

got = cirq.optimize_for_target_gateset(
input_circuit, gateset=gateset, max_num_passes=max_num_passes
)
assert got == desired_circuit
Loading