Skip to content

Commit

Permalink
Add cirq.CZTargetGateset to replace `cirq.ConvertToCZAndSingleGates…
Browse files Browse the repository at this point in the history
…` and `cirq.MergeInteractions` (quantumlib#5007)

* Add CZTargetGateset to replace ConvertToCZAndSingleGates and MergeInteractions

* Address feedback and add serialization support for CZTargetGateset.

* Flatten new optree before computing new_2q_gate_count so it works when op tree is a generator.
  • Loading branch information
tanujkhattar authored Feb 25, 2022
1 parent a0d0b9b commit 627e12b
Show file tree
Hide file tree
Showing 18 changed files with 746 additions and 39 deletions.
2 changes: 2 additions & 0 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@
align_left,
align_right,
CompilationTargetGateset,
CZTargetGateset,
compute_cphase_exponents_for_fsim_decomposition,
decompose_clifford_tableau_to_operations,
decompose_cphase_into_two_fsim,
Expand Down Expand Up @@ -401,6 +402,7 @@
two_qubit_matrix_to_operations,
two_qubit_matrix_to_sqrt_iswap_operations,
two_qubit_gate_product_tabulation,
TwoQubitCompilationTargetGateset,
TwoQubitGateTabulation,
TwoQubitGateTabulationResult,
toggle_tags,
Expand Down
7 changes: 4 additions & 3 deletions cirq/contrib/paulistring/convert_gate_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from cirq import circuits, optimizers, transformers
from cirq import circuits, transformers

from cirq.contrib.paulistring.convert_to_pauli_string_phasors import ConvertToPauliStringPhasors

Expand All @@ -26,8 +26,9 @@ def converted_gate_set(
{SingleQubitCliffordGate,
CZ/PauliInteractionGate, PauliStringPhasor}.
"""
conv_circuit = circuits.Circuit(circuit)
optimizers.ConvertToCzAndSingleGates().optimize_circuit(conv_circuit)
conv_circuit = transformers.optimize_for_target_gateset(
circuit, gateset=transformers.CZTargetGateset()
)
conv_circuit = transformers.merge_k_qubit_unitaries(conv_circuit, k=1)
ConvertToPauliStringPhasors(
ignore_failures=True,
Expand Down
6 changes: 3 additions & 3 deletions cirq/contrib/paulistring/convert_gate_set_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def test_converts_large_circuit():
'''
0: ───Y^0.5───@───[Z]^-0.304───[X]^(1/3)───[Z]^0.446───@───
│ │
1: ───────────@───@────────────────────────────────────@───
2: ───────────────@────────────────────────────────────────
1: ───────────@────────────────@───────────────────────@───
2: ────────────────────────────@───────────────────────────
''',
)
35 changes: 25 additions & 10 deletions cirq/contrib/paulistring/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Sequence
from typing import Callable

from cirq import ops, circuits, optimizers
from cirq import ops, circuits, transformers
from cirq.contrib.paulistring.pauli_string_optimize import pauli_string_optimized_circuit
from cirq.contrib.paulistring.clifford_optimize import clifford_optimized_circuit


class _CZTargetGateSet(transformers.CZTargetGateset):
"""Private implementation of `cirq.CZTargetGateset` used for optimized_circuit method below.
The implementation extends `cirq.CZTargetGateset` by modifying decomposed operations using
`post_clean_up` before putting them back in the circuit.
"""

def __init__(
self,
post_clean_up: Callable[[ops.OP_TREE], ops.OP_TREE] = lambda op_tree: op_tree,
):
super().__init__()
self.post_clean_up = post_clean_up

def _decompose_two_qubit_operation(self, op: ops.Operation, _) -> ops.OP_TREE:
ret = super()._decompose_two_qubit_operation(op, _)
return ret if ret is NotImplemented else self.post_clean_up(ret)


def optimized_circuit(
circuit: circuits.Circuit, atol: float = 1e-8, repeat: int = 10, merge_interactions: bool = True
) -> circuits.Circuit:
circuit = circuits.Circuit(circuit) # Make a copy
gateset = _CZTargetGateSet(post_clean_up=_optimized_ops)
for _ in range(repeat):
start_len = len(circuit)
start_cz_count = _cz_count(circuit)
if merge_interactions:
optimizers.MergeInteractions(
allow_partial_czs=False,
post_clean_up=_optimized_ops,
).optimize_circuit(circuit)
circuit = transformers.optimize_for_target_gateset(circuit, gateset=gateset)
circuit2 = pauli_string_optimized_circuit(circuit, move_cliffords=False, atol=atol)
circuit3 = clifford_optimized_circuit(circuit2, atol=atol)
if len(circuit3) == start_len and _cz_count(circuit3) == start_cz_count:
Expand All @@ -39,12 +56,10 @@ def optimized_circuit(
return circuit


def _optimized_ops(
ops: Sequence[ops.Operation], atol: float = 1e-8, repeat: int = 10
) -> ops.OP_TREE:
def _optimized_ops(ops: ops.OP_TREE, atol: float = 1e-8, repeat: int = 10) -> ops.OP_TREE:
c = circuits.Circuit(ops)
c_opt = optimized_circuit(c, atol, repeat, merge_interactions=False)
return c_opt.all_operations()
return [*c_opt.all_operations()]


def _cz_count(circuit):
Expand Down
1 change: 1 addition & 0 deletions cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def _parallel_gate_op(gate, qubits):
'CSwapGate': cirq.CSwapGate,
'CXPowGate': cirq.CXPowGate,
'CZPowGate': cirq.CZPowGate,
'CZTargetGateset': cirq.CZTargetGateset,
'DensePauliString': cirq.DensePauliString,
'DepolarizingChannel': cirq.DepolarizingChannel,
'DeviceMetadata': cirq.DeviceMetadata,
Expand Down
5 changes: 4 additions & 1 deletion cirq/optimizers/convert_to_cz_and_single_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@

from typing import Optional

from cirq import circuits, ops, protocols
from cirq import circuits, ops, protocols, _compat
from cirq.transformers.analytical_decompositions import two_qubit_to_cz


@_compat.deprecated_class(
deadline='v1.0', fix='Use cirq.optimize_for_target_gateset and cirq.CZTargetGateset instead.'
)
class ConvertToCzAndSingleGates(circuits.PointOptimizer):
"""Attempts to convert strange multi-qubit gates into CZ and single qubit
gates.
Expand Down
32 changes: 22 additions & 10 deletions cirq/optimizers/convert_to_cz_and_single_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,17 @@ def _decompose_(self, qubits):

a, b = cirq.LineQubit.range(2)
c = cirq.Circuit(OtherXX()(a, b), OtherOtherXX()(a, b))
cirq.ConvertToCzAndSingleGates().optimize_circuit(c)
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
cirq.ConvertToCzAndSingleGates().optimize_circuit(c)
assert len(c) == 2


def test_kak_decomposes_unknown_two_qubit_gate():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(cirq.ISWAP(q0, q1))
c_orig = cirq.Circuit(circuit)
cirq.ConvertToCzAndSingleGates().optimize_circuit(circuit)
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
cirq.ConvertToCzAndSingleGates().optimize_circuit(circuit)

assert sum(1 for op in circuit.all_operations() if len(op.qubits) > 1) == 2
assert sum(1 for op in circuit.all_operations() if isinstance(op.gate, cirq.CZPowGate)) == 2
Expand Down Expand Up @@ -84,7 +86,8 @@ def _decompose_(self, qubits):
cirq.Y(q1) ** 0.5,
)
c_orig = cirq.Circuit(circuit)
cirq.ConvertToCzAndSingleGates().optimize_circuit(circuit)
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
cirq.ConvertToCzAndSingleGates().optimize_circuit(circuit)

cirq.testing.assert_allclose_up_to_global_phase(
circuit.unitary(), expected.unitary(), atol=1e-7
Expand All @@ -101,7 +104,8 @@ class UnsupportedDummy(cirq.testing.TwoQubitGate):
UnsupportedDummy()(q0, q1),
)
c_orig = cirq.Circuit(circuit)
cirq.ConvertToCzAndSingleGates(ignore_failures=True).optimize_circuit(circuit)
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
cirq.ConvertToCzAndSingleGates(ignore_failures=True).optimize_circuit(circuit)

assert circuit == c_orig

Expand All @@ -115,7 +119,10 @@ class UnsupportedDummy(cirq.testing.TwoQubitGate):
UnsupportedDummy()(q0, q1),
)
with pytest.raises(TypeError):
cirq.ConvertToCzAndSingleGates().optimize_circuit(circuit)
with cirq.testing.assert_deprecated(
"Use cirq.optimize_for_target_gateset", deadline='v1.0'
):
cirq.ConvertToCzAndSingleGates().optimize_circuit(circuit)


def test_passes_through_measurements():
Expand All @@ -125,7 +132,8 @@ def test_passes_through_measurements():
cirq.measure(q1, q2, key='m1', invert_mask=(True, False)),
)
c_orig = cirq.Circuit(circuit)
cirq.ConvertToCzAndSingleGates().optimize_circuit(circuit)
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
cirq.ConvertToCzAndSingleGates().optimize_circuit(circuit)
assert circuit == c_orig


Expand All @@ -136,7 +144,8 @@ def test_allow_partial_czs():
cirq.CZPowGate(exponent=0.5, global_shift=-0.5).on(q0, q1),
)
c_orig = cirq.Circuit(circuit)
cirq.ConvertToCzAndSingleGates(allow_partial_czs=True).optimize_circuit(circuit)
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
cirq.ConvertToCzAndSingleGates(allow_partial_czs=True).optimize_circuit(circuit)

assert circuit == c_orig

Expand All @@ -147,7 +156,8 @@ def test_allow_partial_czs():
[0, 0, 1, 0],
[0, 0, 0, 1j]]))).on(q0, q1))
# yapf: enable
cirq.ConvertToCzAndSingleGates(allow_partial_czs=True).optimize_circuit(circuit2)
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
cirq.ConvertToCzAndSingleGates(allow_partial_czs=True).optimize_circuit(circuit2)
two_qubit_ops = list(circuit2.findall_operations(lambda e: len(e.qubits) == 2))
assert len(two_qubit_ops) == 1
gate = two_qubit_ops[0][1].gate
Expand All @@ -161,14 +171,16 @@ def test_dont_allow_partial_czs():
cirq.CZPowGate(exponent=1, global_shift=-0.5).on(q0, q1),
)
c_orig = cirq.Circuit(circuit)
cirq.ConvertToCzAndSingleGates().optimize_circuit(circuit)
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
cirq.ConvertToCzAndSingleGates().optimize_circuit(circuit)
assert circuit == c_orig

circuit = cirq.Circuit(
cirq.CZ(q0, q1) ** 0.5,
)
c_orig = cirq.Circuit(circuit)
cirq.ConvertToCzAndSingleGates(ignore_failures=True).optimize_circuit(circuit)
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
cirq.ConvertToCzAndSingleGates(ignore_failures=True).optimize_circuit(circuit)

assert sum(1 for op in circuit.all_operations() if len(op.qubits) > 1) == 2
assert sum(1 for op in circuit.all_operations() if isinstance(op.gate, cirq.CZPowGate)) == 2
Expand Down
5 changes: 4 additions & 1 deletion cirq/optimizers/merge_interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import abc
import numpy as np

from cirq import circuits, ops, protocols
from cirq import circuits, ops, protocols, _compat
from cirq.transformers.analytical_decompositions import two_qubit_to_cz

if TYPE_CHECKING:
Expand Down Expand Up @@ -206,6 +206,9 @@ def _flip_kron_order(mat4x4: np.ndarray) -> np.ndarray:
return result


@_compat.deprecated_class(
deadline='v1.0', fix='Use cirq.optimize_for_target_gateset and cirq.CZTargetGateset instead.'
)
class MergeInteractions(MergeInteractionsAbc):
"""Combines series of adjacent one- and two-qubit, non-parametrized gates
operating on a pair of qubits and replaces each series with the minimum
Expand Down
25 changes: 16 additions & 9 deletions cirq/optimizers/merge_interactions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@

def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit):
actual = cirq.Circuit(before)
opt = cirq.MergeInteractions()
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
opt = cirq.MergeInteractions()
opt.optimize_circuit(actual)

# Ignore differences that would be caught by follow-up optimizations.
Expand All @@ -45,7 +46,8 @@ def assert_optimization_not_broken(circuit):
global phase and rounding error) as the unitary matrix of the optimized
circuit."""
u_before = circuit.unitary()
cirq.MergeInteractions().optimize_circuit(circuit)
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
cirq.MergeInteractions().optimize_circuit(circuit)
u_after = circuit.unitary()

cirq.testing.assert_allclose_up_to_global_phase(u_before, u_after, atol=1e-8)
Expand Down Expand Up @@ -159,15 +161,17 @@ def test_optimizes_single_iswap():
a, b = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.ISWAP(a, b))
assert_optimization_not_broken(c)
cirq.MergeInteractions().optimize_circuit(c)
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
cirq.MergeInteractions().optimize_circuit(c)
assert len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 2


def test_optimizes_tagged_partial_cz():
a, b = cirq.LineQubit.range(2)
c = cirq.Circuit((cirq.CZ ** 0.5)(a, b).with_tags('mytag'))
assert_optimization_not_broken(c)
cirq.MergeInteractions(allow_partial_czs=False).optimize_circuit(c)
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
cirq.MergeInteractions(allow_partial_czs=False).optimize_circuit(c)
assert (
len([1 for op in c.all_operations() if len(op.qubits) == 2]) == 2
), 'It should take 2 CZ gates to decompose a CZ**0.5 gate'
Expand All @@ -178,7 +182,8 @@ def test_not_decompose_czs():
cirq.CZPowGate(exponent=1, global_shift=-0.5).on(*cirq.LineQubit.range(2))
)
circ_orig = circuit.copy()
cirq.MergeInteractions(allow_partial_czs=False).optimize_circuit(circuit)
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
cirq.MergeInteractions(allow_partial_czs=False).optimize_circuit(circuit)
assert circ_orig == circuit


Expand All @@ -195,7 +200,8 @@ def test_not_decompose_czs():
),
)
def test_decompose_partial_czs(circuit):
optimizer = cirq.MergeInteractions(allow_partial_czs=False)
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
optimizer = cirq.MergeInteractions(allow_partial_czs=False)
optimizer.optimize_circuit(circuit)

cz_gates = [
Expand All @@ -213,8 +219,8 @@ def test_not_decompose_partial_czs():
circuit = cirq.Circuit(
cirq.CZPowGate(exponent=0.1, global_shift=-0.5)(*cirq.LineQubit.range(2)),
)

optimizer = cirq.MergeInteractions(allow_partial_czs=True)
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
optimizer = cirq.MergeInteractions(allow_partial_czs=True)
optimizer.optimize_circuit(circuit)

cz_gates = [
Expand Down Expand Up @@ -247,7 +253,8 @@ def clean_up(operations):
yield operations
yield Marker()(a, b)

optimizer = cirq.MergeInteractions(allow_partial_czs=False, post_clean_up=clean_up)
with cirq.testing.assert_deprecated("Use cirq.optimize_for_target_gateset", deadline='v1.0'):
optimizer = cirq.MergeInteractions(allow_partial_czs=False, post_clean_up=clean_up)
optimizer.optimize_circuit(circuit)
circuit = cirq.drop_empty_moments(circuit)

Expand Down
12 changes: 12 additions & 0 deletions cirq/protocols/json_test_data/CZTargetGateset.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[
{
"cirq_type": "CZTargetGateset",
"atol": 1e-06,
"allow_partial_czs": false
},
{
"cirq_type": "CZTargetGateset",
"atol": 1e-08,
"allow_partial_czs": true
}
]
4 changes: 4 additions & 0 deletions cirq/protocols/json_test_data/CZTargetGateset.repr
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[
cirq.CZTargetGateset(atol=1e-06, allow_partial_czs=False),
cirq.CZTargetGateset(atol=1e-08, allow_partial_czs=True),
]
1 change: 1 addition & 0 deletions cirq/protocols/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
'OperationTarget',
# Abstract base class for creating compilation targets.
'CompilationTargetGateset',
'TwoQubitCompilationTargetGateset',
# Circuit optimizers are function-like. Only attributes
# are ignore_failures, tolerance, and other feature flags
'AlignLeft',
Expand Down
2 changes: 2 additions & 0 deletions cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@

from cirq.transformers.target_gatesets import (
CompilationTargetGateset,
CZTargetGateset,
TwoQubitCompilationTargetGateset,
)

from cirq.transformers.align import align_left, align_right
Expand Down
7 changes: 6 additions & 1 deletion cirq/transformers/target_gatesets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,9 @@

"""Gatesets which can act as compilation targets in Cirq."""

from cirq.transformers.target_gatesets.compilation_target_gateset import CompilationTargetGateset
from cirq.transformers.target_gatesets.compilation_target_gateset import (
CompilationTargetGateset,
TwoQubitCompilationTargetGateset,
)

from cirq.transformers.target_gatesets.cz_gateset import CZTargetGateset
Loading

0 comments on commit 627e12b

Please sign in to comment.