Skip to content

Commit

Permalink
Add merge_operations and merge_moments transformer primitives (qu…
Browse files Browse the repository at this point in the history
…antumlib#4707)

* merge_operations and merge_moments transformer primitives

* Refactor to use features compatible with python3.6

* Add complexity tests for merge_operations

* Add iteration info to the docstring
  • Loading branch information
tanujkhattar authored and Nate Thompson committed Dec 11, 2021
1 parent e97768c commit 5e468de
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@
map_moments,
map_operations,
map_operations_and_unroll,
merge_moments,
merge_operations,
merge_single_qubit_gates_into_phased_x_z,
merge_single_qubit_gates_into_phxz,
MergeInteractions,
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@
map_moments,
map_operations,
map_operations_and_unroll,
merge_moments,
merge_operations,
unroll_circuit_op,
unroll_circuit_op_greedy_earliest,
unroll_circuit_op_greedy_frontier,
Expand Down
117 changes: 117 additions & 0 deletions cirq-core/cirq/optimizers/transformer_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Callable,
Dict,
Hashable,
List,
Optional,
Sequence,
TYPE_CHECKING,
Expand Down Expand Up @@ -129,6 +130,122 @@ def map_operations_and_unroll(
return unroll_circuit_op(map_operations(circuit, map_func))


def merge_operations(
circuit: CIRCUIT_TYPE,
merge_func: Callable[[ops.Operation, ops.Operation], Optional[ops.Operation]],
) -> CIRCUIT_TYPE:
"""Merges operations in a circuit by calling `merge_func` iteratively on operations.
Two operations op1 and op2 are merge-able if
- There is no other operations between op1 and op2 in the circuit
- is_subset(op1.qubits, op2.qubits) or is_subset(op2.qubits, op1.qubits)
The `merge_func` is a callable which, given two merge-able operations
op1 and op2, decides whether they should be merged into a single operation
or not. If not, it should return None, else it should return the single merged
operations `op`.
The method iterates on the input circuit moment-by-moment from left to right and attempts
to repeatedly merge each operation in the latest moment with all the corresponding merge-able
operations to it's left.
If op1 and op2 are merged, both op1 and op2 are deleted from the circuit and
the resulting `merged_op` is inserted at the index corresponding to the larger
of op1/op2. If both op1 and op2 act on the same number of qubits, `merged_op` is
inserted in the smaller moment index to minimize circuit depth.
The number of calls to `merge_func` is O(N), where N = Total no. of operations, because:
- Every time the `merge_func` returns a new operation, the number of operations in the
circuit reduce by 1 and hence this can happen at most O(N) times
- Every time the `merge_func` returns None, the current operation is inserted into the
frontier and we go on to process the next operation, which can also happen at-most
O(N) times.
Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
merge_func: Callable to determine whether two merge-able operations in the circuit should
be merged. If the operations can be merged, the callable should return the merged
operation, else None.
Returns:
Copy of input circuit with merged operations.
Raises:
ValueError if the merged operation acts on new qubits outside the set of qubits
corresponding to the original operations to be merged.
"""

def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Operation]:
new_op = merge_func(op1, op2)
qubit_set = frozenset(op1.qubits + op2.qubits)
if new_op is not None and not qubit_set.issuperset(new_op.qubits):
raise ValueError(
f"Merged operation {new_op} must act on a subset of qubits of "
f"original operations {op1} and {op2}"
)
return new_op

ret_circuit = circuits.Circuit()
for current_moment in circuit:
new_moment = ops.Moment()
for op in current_moment:
op_qs = set(op.qubits)
idx = ret_circuit.prev_moment_operating_on(tuple(op_qs))
if idx is not None and op_qs.issubset(ret_circuit[idx][op_qs].operations[0].qubits):
# Case-1: Try to merge op with the larger operation on the left.
left_op = ret_circuit[idx][op_qs].operations[0]
new_op = apply_merge_func(left_op, op)
if new_op is not None:
ret_circuit.batch_replace([(idx, left_op, new_op)])
else:
new_moment = new_moment.with_operation(op)
continue

while idx is not None and len(op_qs) > 0:
# Case-2: left_ops will merge right into `op` whenever possible.
for left_op in ret_circuit[idx][op_qs].operations:
is_merged = False
if op_qs.issuperset(left_op.qubits):
# Try to merge left_op into op
new_op = apply_merge_func(left_op, op)
if new_op is not None:
ret_circuit.batch_remove([(idx, left_op)])
op, is_merged = new_op, True
if not is_merged:
op_qs -= frozenset(left_op.qubits)
idx = ret_circuit.prev_moment_operating_on(tuple(op_qs))
new_moment = new_moment.with_operation(op)
ret_circuit += new_moment
return _to_target_circuit_type(ret_circuit, circuit)


def merge_moments(
circuit: CIRCUIT_TYPE,
merge_func: Callable[[ops.Moment, ops.Moment], Optional[ops.Moment]],
) -> CIRCUIT_TYPE:
"""Merges adjacent moments, one by one from left to right, by calling `merge_func(m1, m2)`.
Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
merge_func: Callable to determine whether two adjacent moments in the circuit should be
merged. If the moments can be merged, the callable should return the merged moment,
else None.
Returns:
Copy of input circuit with merged moments.
"""
if not circuit:
return circuit
merged_moments: List[ops.Moment] = [circuit[0]]
for current_moment in circuit[1:]:
merged_moment = merge_func(merged_moments[-1], current_moment)
if not merged_moment:
merged_moments.append(current_moment)
else:
merged_moments[-1] = merged_moment
return _create_target_circuit_type(merged_moments, circuit)


def _check_circuit_op(op, tags_to_check: Optional[Sequence[Hashable]]):
return isinstance(op.untagged, circuits.CircuitOperation) and (
tags_to_check is None or any(tag in op.tags for tag in tags_to_check)
Expand Down
146 changes: 146 additions & 0 deletions cirq-core/cirq/optimizers/transformer_primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
import pytest

import cirq
from cirq.optimizers.transformer_primitives import MAPPED_CIRCUIT_OP_TAG

Expand Down Expand Up @@ -198,3 +200,147 @@ def test_map_moments_drop_empty_moments():
c = cirq.Circuit(cirq.Moment(op), cirq.Moment(), cirq.Moment(op))
c_mapped = cirq.map_moments(c, lambda m, i: [] if len(m) == 0 else [m])
cirq.testing.assert_same_circuits(c_mapped, cirq.Circuit(c[0], c[0]))


def test_merge_moments():
q = cirq.LineQubit.range(3)
c_orig = cirq.Circuit(
cirq.Z.on_each(q[0], q[1]),
cirq.Z.on_each(q[1], q[2]),
cirq.Z.on_each(q[1], q[0]),
strategy=cirq.InsertStrategy.NEW_THEN_INLINE,
)
c_orig = cirq.Circuit(c_orig, cirq.CCX(*q), c_orig)
cirq.testing.assert_has_diagram(
c_orig,
'''
0: ───Z───────Z───@───Z───────Z───
1: ───Z───Z───Z───@───Z───Z───Z───
2: ───────Z───────X───────Z───────
''',
)

def merge_func(m1: cirq.Moment, m2: cirq.Moment) -> Optional[cirq.Moment]:
def is_z_moment(m):
return all(op.gate == cirq.Z for op in m)

if not (is_z_moment(m1) and is_z_moment(m2)):
return None
qubits = m1.qubits | m2.qubits

def mul(op1, op2):
return (op1 or op2) if not (op1 and op2) else cirq.decompose_once(op1 * op2)

return cirq.Moment(mul(m1.operation_at(q), m2.operation_at(q)) for q in qubits)

cirq.testing.assert_has_diagram(
cirq.merge_moments(c_orig, merge_func),
'''
0: ───────@───────
1: ───Z───@───Z───
2: ───Z───X───Z───
''',
)


def test_merge_moments_empty_circuit():
def fail_if_called_func(*_):
assert False

c = cirq.Circuit()
assert cirq.merge_moments(c, fail_if_called_func) is c


def test_merge_operations_raises():
q = cirq.LineQubit.range(3)
c = cirq.Circuit(cirq.CZ(*q[:2]), cirq.X(q[0]))
with pytest.raises(ValueError, match='must act on a subset of qubits'):
cirq.merge_operations(c, lambda *_: cirq.X(q[2]))


def test_merge_operations_nothing_to_merge():
def fail_if_called_func(*_):
assert False

# Empty Circuit.
c = cirq.Circuit()
assert cirq.merge_operations(c, fail_if_called_func) == c
# Single moment
q = cirq.LineQubit.range(3)
c += cirq.Moment(cirq.CZ(*q[:2]))
assert cirq.merge_operations(c, fail_if_called_func) == c
# Multi moment with disjoint operations + global phase operation.
c += cirq.Moment(cirq.X(q[2]), cirq.GlobalPhaseOperation(1j))
assert cirq.merge_operations(c, fail_if_called_func) == c


def test_merge_operations_merges_connected_component():
q = cirq.LineQubit.range(3)
c_orig = cirq.Circuit(
cirq.Moment(cirq.H.on_each(*q)),
cirq.CNOT(q[0], q[2]),
cirq.CNOT(*q[0:2]),
cirq.H(q[0]),
cirq.CZ(*q[:2]),
cirq.X(q[0]),
cirq.Y(q[1]),
cirq.CNOT(*q[0:2]),
cirq.CNOT(*q[1:3]),
cirq.X(q[0]),
cirq.Y(q[1]),
cirq.CNOT(*q[:2]),
strategy=cirq.InsertStrategy.NEW,
)
cirq.testing.assert_has_diagram(
c_orig,
'''
0: ───H───@───@───H───@───X───────@───────X───────@───
│ │ │ │ │
1: ───H───┼───X───────@───────Y───X───@───────Y───X───
│ │
2: ───H───X───────────────────────────X───────────────
''',
)

def merge_func(op1, op2):
"""Artificial example where a CZ will absorb any merge-able operation."""
for op in [op1, op2]:
if op.gate == cirq.CZ:
return op
return None

c_new = cirq.merge_operations(c_orig, merge_func)
cirq.testing.assert_has_diagram(
c_new,
'''
0: ───H───@───────────@───────────────────────────@───
│ │ │
1: ───────┼───────────@───────────────@───────Y───X───
│ │
2: ───H───X───────────────────────────X───────────────''',
)


@pytest.mark.parametrize("op_density", [0.1, 0.5, 0.9])
def test_merge_operations_complexity(op_density):
prng = cirq.value.parse_random_state(11011)
circuit = cirq.testing.random_circuit(20, 500, op_density, random_state=prng)
for merge_func in [
lambda _, __: None,
lambda op1, _: op1,
lambda _, op2: op2,
lambda op1, op2: prng.choice([op1, op2, None]),
]:

def wrapped_merge_func(op1, op2):
wrapped_merge_func.num_function_calls += 1
return merge_func(op1, op2)

wrapped_merge_func.num_function_calls = 0
_ = cirq.merge_operations(circuit, wrapped_merge_func)
total_operations = len([*circuit.all_operations()])
assert wrapped_merge_func.num_function_calls <= 2 * total_operations

0 comments on commit 5e468de

Please sign in to comment.