Skip to content

Commit

Permalink
Improve support for CCOs in cirq.merge_operations (quantumlib#5393)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanujkhattar authored May 26, 2022
1 parent c81339d commit 75e0f69
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 25 deletions.
14 changes: 7 additions & 7 deletions cirq/transformers/optimize_for_target_gateset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,17 +175,17 @@ def test_optimize_for_target_gateset():
c_new,
'''
┌────────┐ ┌────────┐ ┌────────┐
0: ───M[1]──────────M[1]──────────────────────M[1]────Y['ignore']───M────────M[1]───────────────────────────M[1]────M[1]───M[1]───
│ │ │ ║ │ │
1: ───M[2]───M[1]───┼─────────────M[1]────M[1]┼───────Y['ignore']───M────────┼───M[1]───────────M[1]────M[1]┼───────┼──────M[2]───
0: ───M[1]──────────M[1]──────────────────────M[1]────Y['ignore']───M────────────M[1]───────────────────M[1]────────M[1]───M[1]───
│ │ │ ║ │ │
1: ───M[2]───M[1]───┼─────────────M[1]────M[1]┼───────Y['ignore']───M────────M[1]──────────────M[1]────┼───M[1]────┼──────M[2]───
│ │ │ │ │ ║ │ │ │ │ │ │
2: ──────────M[2]───M[2]───M[1]───┼───────M[2]┼───────@['ignore']───╫───@────┼───M[2]────M[1]───┼───────M[2]┼───────M[2]──────────
│ │ │ │ ║ ║ │ │
3: ────────────────────────M[2]───M[2]────────M[2]────X─────────────╫───@────M[2]────────M[2]───M[2]────────M[2]──────────────────
2: ──────────M[2]───M[2]───M[1]───┼───────M[2]┼───────@['ignore']───╫───@────M[2]───────M[1]───┼───────┼───M[2]────M[2]──────────
│ │ │ │ ║ ║ │ │ │
3: ────────────────────────M[2]───M[2]────────M[2]────X─────────────╫───@────────M[2]────M[2]───M[2]────M[2]──────────────────────
║ ║
m: ═════════════════════════════════════════════════════════════════@═══^═════════════════════════════════════════════════════════
└────────┘ └────────┘ └────────┘
''',
''',
)

with pytest.raises(ValueError, match="Unable to convert"):
Expand Down
127 changes: 109 additions & 18 deletions cirq/transformers/transformer_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,22 @@
"""Defines primitives for common transformer patterns."""

from collections import defaultdict
from typing import cast, Callable, Dict, Hashable, List, Optional, Sequence, Union, TYPE_CHECKING
import bisect
import dataclasses

from typing import (
cast,
Callable,
Dict,
Hashable,
List,
Optional,
Sequence,
Set,
Union,
Tuple,
TYPE_CHECKING,
)

from cirq import circuits, ops, protocols
from cirq.circuits.circuit import CIRCUIT_TYPE
Expand Down Expand Up @@ -186,6 +201,80 @@ def map_operations_and_unroll(
)


@dataclasses.dataclass
class _MergedCircuit:
"""An optimized internal representation of a circuit, tailored for `cirq.merge_operations`
Attributes:
qubit_indexes: Mapping from qubits to (sorted) list of moment indexes containing operations
acting on the qubit.
mkey_indexes: Mapping from measurement keys to (sorted) list of moment indexes containing
measurement operations with the same key.
ckey_indexes: Mapping from measurement keys to (sorted) list of moment indexes containing
classically controlled operations controlled on the same key.
ops_by_index: List of circuit moments containing operations. We use a dictionary instead
of a set to store operations to preserve insertion order.
"""

qubit_indexes: Dict['cirq.Qid', List[int]] = dataclasses.field(
default_factory=lambda: defaultdict(lambda: [-1])
)
mkey_indexes: Dict['cirq.MeasurementKey', List[int]] = dataclasses.field(
default_factory=lambda: defaultdict(lambda: [-1])
)
ckey_indexes: Dict['cirq.MeasurementKey', List[int]] = dataclasses.field(
default_factory=lambda: defaultdict(lambda: [-1])
)
ops_by_index: List[Dict['cirq.Operation', int]] = dataclasses.field(default_factory=list)

def append_empty_moment(self) -> None:
self.ops_by_index.append({})

def add_op_to_moment(self, moment_index: int, op: 'cirq.Operation') -> None:
self.ops_by_index[moment_index][op] = 0
for q in op.qubits:
if moment_index > self.qubit_indexes[q][-1]:
self.qubit_indexes[q].append(moment_index)
else:
bisect.insort(self.qubit_indexes[q], moment_index)
for mkey in protocols.measurement_key_objs(op):
bisect.insort(self.mkey_indexes[mkey], moment_index)
for ckey in protocols.control_keys(op):
bisect.insort(self.ckey_indexes[ckey], moment_index)

def remove_op_from_moment(self, moment_index: int, op: 'cirq.Operation') -> None:
self.ops_by_index[moment_index].pop(op)
for q in op.qubits:
if self.qubit_indexes[q][-1] == moment_index:
self.qubit_indexes[q].pop()
else:
self.qubit_indexes[q].remove(moment_index)
for mkey in protocols.measurement_key_objs(op):
self.mkey_indexes[mkey].remove(moment_index)
for ckey in protocols.control_keys(op):
self.ckey_indexes[ckey].remove(moment_index)

def get_mergeable_ops(
self, op: 'cirq.Operation', op_qs: Set['cirq.Qid']
) -> Tuple[int, List['cirq.Operation']]:
# Find the index of previous moment which can be merged with `op`.
idx = max([self.qubit_indexes[q][-1] for q in op_qs], default=-1)
idx = max([idx] + [self.mkey_indexes[ckey][-1] for ckey in protocols.control_keys(op)])
idx = max(
[idx] + [self.ckey_indexes[mkey][-1] for mkey in protocols.measurement_key_objs(op)]
)
# Return the set of overlapping ops in moment with index `idx`.
if idx == -1:
return idx, []

return idx, [
left_op for left_op in self.ops_by_index[idx] if not op_qs.isdisjoint(left_op.qubits)
]

def get_cirq_circuit(self) -> 'cirq.Circuit':
return circuits.Circuit(circuits.Moment(m.keys()) for m in self.ops_by_index)


def merge_operations(
circuit: CIRCUIT_TYPE,
merge_func: Callable[[ops.Operation, ops.Operation], Optional[ops.Operation]],
Expand Down Expand Up @@ -254,54 +343,56 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Ope
)
return new_op

ret_circuit = circuits.Circuit()
for current_moment in circuit:
new_moment = circuits.Moment()
merged_circuit = _MergedCircuit()
for moment_idx, current_moment in enumerate(cast(List['cirq.Moment'], circuit)):
merged_circuit.append_empty_moment()
for op in sorted(current_moment.operations, key=lambda op: op.qubits):
if (
deep
and isinstance(op.untagged, circuits.CircuitOperation)
and tags_to_ignore_set.isdisjoint(op.tags)
):
op_untagged = op.untagged
new_moment = new_moment.with_operation(
merged_circuit.add_op_to_moment(
moment_idx,
op_untagged.replace(
circuit=merge_operations(
op_untagged.circuit,
merge_func,
tags_to_ignore=tags_to_ignore,
deep=True,
)
).with_tags(*op.tags, _circuit_op_tag)
).with_tags(*op.tags, _circuit_op_tag),
)
continue

op_qs = set(op.qubits)
idx = ret_circuit.prev_moment_operating_on(tuple(op_qs))
if idx is not None and op_qs.issubset(ret_circuit[idx][op_qs].operations[0].qubits):
left_idx, left_ops = merged_circuit.get_mergeable_ops(op, op_qs)
if len(left_ops) == 1 and op_qs.issubset(left_ops[0].qubits):
# Case-1: Try to merge op with the larger operation on the left.
left_op = ret_circuit[idx][op_qs].operations[0]
new_op = apply_merge_func(left_op, op)
new_op = apply_merge_func(left_ops[0], op)
if new_op is not None:
ret_circuit.batch_replace([(idx, left_op, new_op)])
merged_circuit.remove_op_from_moment(left_idx, left_ops[0])
merged_circuit.add_op_to_moment(left_idx, new_op)
else:
new_moment = new_moment.with_operation(op)
merged_circuit.add_op_to_moment(moment_idx, op)
continue

while idx is not None and len(op_qs) > 0:
while left_ops and op_qs:
# Case-2: left_ops will merge right into `op` whenever possible.
for left_op in ret_circuit[idx][op_qs].operations:
for left_op in left_ops:
is_merged = False
if op_qs.issuperset(left_op.qubits):
# Try to merge left_op into op
new_op = apply_merge_func(left_op, op)
if new_op is not None:
ret_circuit.batch_remove([(idx, left_op)])
merged_circuit.remove_op_from_moment(left_idx, left_op)
op, is_merged = new_op, True
if not is_merged:
op_qs -= frozenset(left_op.qubits)
idx = ret_circuit.prev_moment_operating_on(tuple(op_qs))
new_moment = new_moment.with_operation(op)
ret_circuit += new_moment
left_idx, left_ops = merged_circuit.get_mergeable_ops(op, op_qs)
merged_circuit.add_op_to_moment(moment_idx, op)
ret_circuit = merged_circuit.get_cirq_circuit()
if deep:
ret_circuit = map_operations(
ret_circuit,
Expand Down
35 changes: 35 additions & 0 deletions cirq/transformers/transformer_primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,3 +810,38 @@ def wrapped_merge_func(op1, op2):
_ = cirq.merge_operations(circuit, wrapped_merge_func)
total_operations = len([*circuit.all_operations()])
assert wrapped_merge_func.num_function_calls <= 2 * total_operations


def test_merge_operations_does_not_merge_ccos_behind_measurements():
q = cirq.LineQubit.range(2)
cco_op = cirq.X(q[1]).with_classical_controls("a")

def merge_func(op1, op2):
return cirq.I(*op1.qubits) if op1 == cco_op and op2 == cco_op else None

circuit = cirq.Circuit([cirq.H(q[0]), cirq.measure(q[0], key="a"), cco_op] * 2)
cirq.testing.assert_same_circuits(cirq.merge_operations(circuit, merge_func), circuit)

circuit = cirq.Circuit([cirq.H(q[0]), cirq.measure(q[0], key="a"), cco_op, cco_op] * 2)
expected_circuit = cirq.Circuit([cirq.H(q[0]), cirq.measure(q[0], key="a"), cirq.I(q[1])] * 2)
cirq.testing.assert_same_circuits(
cirq.align_left(cirq.merge_operations(circuit, merge_func)), expected_circuit
)


def test_merge_operations_does_not_merge_measurements_behind_ccos():
q = cirq.LineQubit.range(2)
measure_op = cirq.measure(q[0], key="a")
cco_op = cirq.X(q[1]).with_classical_controls("a")

def merge_func(op1, op2):
return cirq.I(*op1.qubits) if op1 == measure_op and op2 == measure_op else None

circuit = cirq.Circuit([cirq.H(q[0]), measure_op, cco_op] * 2)
cirq.testing.assert_same_circuits(cirq.merge_operations(circuit, merge_func), circuit)

circuit = cirq.Circuit([cirq.H(q[0]), measure_op, cco_op, measure_op, measure_op] * 2)
expected_circuit = cirq.Circuit([cirq.H(q[0]), measure_op, cco_op, cirq.I(q[0])] * 2)
cirq.testing.assert_same_circuits(
cirq.align_left(cirq.merge_operations(circuit, merge_func)), expected_circuit
)

0 comments on commit 75e0f69

Please sign in to comment.