diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 640612f63c3..3d3aade6008 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -362,6 +362,7 @@ decompose_two_qubit_interaction_into_four_fsim_gates, drop_empty_moments, drop_negligible_operations, + eject_z, expand_composite, is_negligible_turn, map_moments, diff --git a/cirq-core/cirq/ion/ion_decomposition.py b/cirq-core/cirq/ion/ion_decomposition.py index 0dfeda160b3..729bb5879aa 100644 --- a/cirq-core/cirq/ion/ion_decomposition.py +++ b/cirq-core/cirq/ion/ion_decomposition.py @@ -54,7 +54,7 @@ def _cleanup_operations(operations: List[ops.Operation]): circuit = circuits.Circuit(operations) optimizers.merge_single_qubit_gates.merge_single_qubit_gates_into_phased_x_z(circuit) optimizers.eject_phased_paulis.EjectPhasedPaulis().optimize_circuit(circuit) - optimizers.eject_z.EjectZ().optimize_circuit(circuit) + circuit = transformers.eject_z(circuit) circuit = circuits.Circuit(circuit.all_operations(), strategy=circuits.InsertStrategy.EARLIEST) return list(circuit.all_operations()) diff --git a/cirq-core/cirq/optimizers/eject_z.py b/cirq-core/cirq/optimizers/eject_z.py index c0e57cfeec9..11954b2f79b 100644 --- a/cirq-core/cirq/optimizers/eject_z.py +++ b/cirq-core/cirq/optimizers/eject_z.py @@ -14,32 +14,13 @@ """An optimization pass that pushes Z gates later and later in the circuit.""" -from typing import cast, Dict, Iterable, List, Optional, Tuple -from collections import defaultdict -import numpy as np -import sympy +from cirq import circuits, transformers -from cirq import circuits, ops, protocols -from cirq.transformers.analytical_decompositions import single_qubit_decompositions - - -def _is_integer(n): - return np.isclose(n, np.round(n)) - - -def _is_swaplike(op: ops.Operation): - if isinstance(op.gate, ops.SwapPowGate): - return op.gate.exponent == 1 - - if isinstance(op.gate, ops.ISwapPowGate): - return _is_integer((op.gate.exponent - 1) / 2) - - if isinstance(op.gate, ops.FSimGate): - return _is_integer(op.gate.theta / np.pi - 1 / 2) - - return False +# from cirq.transformers import eject_z +from cirq._compat import deprecated_class +@deprecated_class(deadline='v1.0', fix='Use cirq.eject_z instead.') class EjectZ: """Pushes Z gates towards the end of the circuit. @@ -62,96 +43,8 @@ def __init__(self, tolerance: float = 0.0, eject_parameterized: bool = False) -> self.eject_parameterized = eject_parameterized def optimize_circuit(self, circuit: circuits.Circuit): - # Tracks qubit phases (in half turns; multiply by pi to get radians). - qubit_phase: Dict[ops.Qid, float] = defaultdict(lambda: 0) - deletions: List[Tuple[int, ops.Operation]] = [] - replacements: List[Tuple[int, ops.Operation, ops.Operation]] = [] - insertions: List[Tuple[int, ops.Operation]] = [] - phased_xz_replacements: Dict[Tuple[int, ops.Qid], int] = {} - - def dump_tracked_phase(qubits: Iterable[ops.Qid], index: int) -> None: - """Zeroes qubit_phase entries by emitting Z gates.""" - for q in qubits: - p = qubit_phase[q] - qubit_phase[q] = 0 - if single_qubit_decompositions.is_negligible_turn(p, self.tolerance): - continue - dumped = False - moment_index = circuit.prev_moment_operating_on([q], index) - if moment_index is not None: - op = circuit.moments[moment_index][q] - if op and isinstance(op.gate, ops.PhasedXZGate): - # Attach z-rotation to replacing PhasedXZ gate. - idx = phased_xz_replacements[moment_index, q] - _, _, repl_op = replacements[idx] - gate = cast(ops.PhasedXZGate, repl_op.gate) - repl_op = gate.with_z_exponent(p * 2).on(q) - replacements[idx] = (moment_index, op, repl_op) - dumped = True - if not dumped: - # Add a new Z gate - dump_op = ops.Z(q) ** (p * 2) - insertions.append((index, dump_op)) - - for moment_index, moment in enumerate(circuit): - for op in moment.operations: - # Move Z gates into tracked qubit phases. - h = _try_get_known_z_half_turns(op, self.eject_parameterized) - if h is not None: - q = op.qubits[0] - qubit_phase[q] += h / 2 - deletions.append((moment_index, op)) - continue - - # Z gate before measurement is a no-op. Drop tracked phase. - if isinstance(op.gate, ops.MeasurementGate): - for q in op.qubits: - qubit_phase[q] = 0 - - # If there's no tracked phase, we can move on. - phases = [qubit_phase[q] for q in op.qubits] - if not isinstance(op.gate, ops.PhasedXZGate) and all( - single_qubit_decompositions.is_negligible_turn(p, self.tolerance) - for p in phases - ): - continue - - if _is_swaplike(op): - a, b = op.qubits - qubit_phase[a], qubit_phase[b] = qubit_phase[b], qubit_phase[a] - continue - - # Try to move the tracked phasing over the operation. - phased_op = op - for i, p in enumerate(phases): - if not single_qubit_decompositions.is_negligible_turn(p, self.tolerance): - phased_op = protocols.phase_by(phased_op, -p, i, default=None) - if phased_op is not None: - gate = phased_op.gate - if isinstance(gate, ops.PhasedXZGate) and ( - self.eject_parameterized or not protocols.is_parameterized(gate.z_exponent) - ): - qubit = phased_op.qubits[0] - qubit_phase[qubit] += gate.z_exponent / 2 - phased_op = gate.with_z_exponent(0).on(qubit) - repl_idx = len(replacements) - phased_xz_replacements[moment_index, qubit] = repl_idx - replacements.append((moment_index, op, phased_op)) - else: - dump_tracked_phase(op.qubits, moment_index) - - dump_tracked_phase(qubit_phase.keys(), len(circuit)) - circuit.batch_remove(deletions) - circuit.batch_replace(replacements) - circuit.batch_insert(insertions) - - -def _try_get_known_z_half_turns(op: ops.Operation, eject_parameterized: bool) -> Optional[float]: - if not isinstance(op, ops.GateOperation): - return None - if not isinstance(op.gate, ops.ZPowGate): - return None - h = op.gate.exponent - if not eject_parameterized and isinstance(h, sympy.Basic): - return None - return h + circuit._moments = [ + *transformers.eject_z( + circuit, atol=self.tolerance, eject_parameterized=self.eject_parameterized + ) + ] diff --git a/cirq-core/cirq/optimizers/eject_z_test.py b/cirq-core/cirq/optimizers/eject_z_test.py index 9876bd3b6a9..22fbee20bd8 100644 --- a/cirq-core/cirq/optimizers/eject_z_test.py +++ b/cirq-core/cirq/optimizers/eject_z_test.py @@ -16,36 +16,36 @@ import sympy import cirq -from cirq.optimizers.eject_z import _try_get_known_z_half_turns def assert_optimizes( before: cirq.Circuit, expected: cirq.Circuit, eject_parameterized: bool = False ): - opt = cirq.EjectZ(eject_parameterized=eject_parameterized) + with cirq.testing.assert_deprecated("Use cirq.eject_z", deadline='v1.0'): + opt = cirq.EjectZ(eject_parameterized=eject_parameterized) - if cirq.has_unitary(before): - cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( - before, expected, atol=1e-8 - ) + if cirq.has_unitary(before): + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + before, expected, atol=1e-8 + ) - circuit = before.copy() - opt.optimize_circuit(circuit) - opt.optimize_circuit(expected) + circuit = before.copy() + opt.optimize_circuit(circuit) + opt.optimize_circuit(expected) - cirq.testing.assert_same_circuits(circuit, expected) + cirq.testing.assert_same_circuits(circuit, expected) - # And it should be idempotent. - opt.optimize_circuit(circuit) - cirq.testing.assert_same_circuits(circuit, expected) + # And it should be idempotent. + opt.optimize_circuit(circuit) + cirq.testing.assert_same_circuits(circuit, expected) def assert_removes_all_z_gates(circuit: cirq.Circuit, eject_parameterized: bool = True): - opt = cirq.EjectZ(eject_parameterized=eject_parameterized) + with cirq.testing.assert_deprecated("Use cirq.eject_z", deadline='v1.0'): + opt = cirq.EjectZ(eject_parameterized=eject_parameterized) optimized = circuit.copy() opt.optimize_circuit(optimized) for op in optimized.all_operations(): - assert _try_get_known_z_half_turns(op, eject_parameterized) is None if isinstance(op.gate, cirq.PhasedXZGate) and ( eject_parameterized or not cirq.is_parameterized(op.gate.z_exponent) ): @@ -373,7 +373,8 @@ def test_swap(): original = cirq.Circuit([cirq.rz(0.123).on(a), cirq.SWAP(a, b)]) optimized = original.copy() - cirq.EjectZ().optimize_circuit(optimized) + with cirq.testing.assert_deprecated("Use cirq.eject_z", deadline='v1.0'): + cirq.EjectZ().optimize_circuit(optimized) optimized = cirq.drop_empty_moments(optimized) assert optimized[0].operations == (cirq.SWAP(a, b),) @@ -384,19 +385,14 @@ def test_swap(): ) -@pytest.mark.parametrize('exponent', (0, 2, 1.1, -2, -1.6)) -def test_not_a_swap(exponent): - a, b = cirq.LineQubit.range(2) - assert not cirq.optimizers.eject_z._is_swaplike(cirq.SWAP(a, b) ** exponent) - - @pytest.mark.parametrize('theta', (np.pi / 2, -np.pi / 2, np.pi / 2 + 5 * np.pi)) def test_swap_fsim(theta): a, b = cirq.LineQubit.range(2) original = cirq.Circuit([cirq.rz(0.123).on(a), cirq.FSimGate(theta=theta, phi=0.123).on(a, b)]) optimized = original.copy() - cirq.EjectZ().optimize_circuit(optimized) + with cirq.testing.assert_deprecated("Use cirq.eject_z", deadline='v1.0'): + cirq.EjectZ().optimize_circuit(optimized) optimized = cirq.drop_empty_moments(optimized) assert optimized[0].operations == (cirq.FSimGate(theta=theta, phi=0.123).on(a, b),) @@ -407,19 +403,13 @@ def test_swap_fsim(theta): ) -@pytest.mark.parametrize('theta', (0, 5 * np.pi, -np.pi)) -def test_not_a_swap_fsim(theta): - a, b = cirq.LineQubit.range(2) - assert not cirq.optimizers.eject_z._is_swaplike(cirq.FSimGate(theta=theta, phi=0.456).on(a, b)) - - @pytest.mark.parametrize('exponent', (1, -1)) def test_swap_iswap(exponent): a, b = cirq.LineQubit.range(2) original = cirq.Circuit([cirq.rz(0.123).on(a), cirq.ISWAP(a, b) ** exponent]) optimized = original.copy() - - cirq.EjectZ().optimize_circuit(optimized) + with cirq.testing.assert_deprecated("Use cirq.eject_z", deadline='v1.0'): + cirq.EjectZ().optimize_circuit(optimized) optimized = cirq.drop_empty_moments(optimized) assert optimized[0].operations == (cirq.ISWAP(a, b) ** exponent,) diff --git a/cirq-core/cirq/optimizers/merge_interactions_test.py b/cirq-core/cirq/optimizers/merge_interactions_test.py index f1dc637748e..c727e5cdb8e 100644 --- a/cirq-core/cirq/optimizers/merge_interactions_test.py +++ b/cirq-core/cirq/optimizers/merge_interactions_test.py @@ -29,13 +29,13 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit): followup_optimizations: List[Callable[[cirq.Circuit], None]] = [ cirq.merge_single_qubit_gates_into_phased_x_z, cirq.EjectPhasedPaulis().optimize_circuit, - cirq.EjectZ().optimize_circuit, ] for post in followup_optimizations: post(actual) post(expected) followup_transformers: List[cirq.TRANSFORMER] = [ + cirq.eject_z, cirq.drop_negligible_operations, cirq.drop_empty_moments, ] diff --git a/cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py b/cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py index 3dbbcb4fdc2..6508aea145f 100644 --- a/cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py +++ b/cirq-core/cirq/optimizers/merge_interactions_to_sqrt_iswap_test.py @@ -40,13 +40,13 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit, **kwargs): followup_optimizations: List[Callable[[cirq.Circuit], None]] = [ cirq.merge_single_qubit_gates_into_phased_x_z, cirq.EjectPhasedPaulis().optimize_circuit, - cirq.EjectZ().optimize_circuit, ] for post in followup_optimizations: post(actual) post(expected) followup_transformers: List[cirq.TRANSFORMER] = [ + cirq.eject_z, cirq.drop_negligible_operations, cirq.drop_empty_moments, ] diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index 70fb448f9dd..cac6ed0875a 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -49,6 +49,8 @@ from cirq.transformers.drop_negligible_operations import drop_negligible_operations +from cirq.transformers.eject_z import eject_z + from cirq.transformers.synchronize_terminal_measurements import synchronize_terminal_measurements from cirq.transformers.transformer_api import ( diff --git a/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz.py b/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz.py index df2509a2440..a78ec1eeebf 100644 --- a/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz.py +++ b/cirq-core/cirq/transformers/analytical_decompositions/two_qubit_to_cz.py @@ -23,8 +23,8 @@ from cirq import ops, linalg, protocols, circuits from cirq.transformers.analytical_decompositions import single_qubit_decompositions +from cirq.transformers.eject_z import eject_z from cirq.optimizers import ( - eject_z, eject_phased_paulis, merge_single_qubit_gates, ) @@ -165,7 +165,7 @@ def _cleanup_operations(operations: Sequence[ops.Operation]): circuit = circuits.Circuit(operations) merge_single_qubit_gates.merge_single_qubit_gates_into_phased_x_z(circuit) eject_phased_paulis.EjectPhasedPaulis().optimize_circuit(circuit) - eject_z.EjectZ().optimize_circuit(circuit) + circuit = eject_z(circuit) circuit = circuits.Circuit(circuit.all_operations(), strategy=circuits.InsertStrategy.EARLIEST) return list(circuit.all_operations()) diff --git a/cirq-core/cirq/transformers/eject_z.py b/cirq-core/cirq/transformers/eject_z.py new file mode 100644 index 00000000000..7ebd12e1047 --- /dev/null +++ b/cirq-core/cirq/transformers/eject_z.py @@ -0,0 +1,145 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformer pass that pushes Z gates later and later in the circuit.""" + +from typing import Dict, Iterable, Optional, Tuple, TYPE_CHECKING +from collections import defaultdict +import numpy as np + +from cirq import ops, protocols +from cirq.transformers import transformer_api, transformer_primitives +from cirq.transformers.analytical_decompositions import single_qubit_decompositions + +if TYPE_CHECKING: + import cirq + + +def _is_integer(n): + return np.isclose(n, np.round(n)) + + +def _is_swaplike(gate: 'cirq.Gate'): + if isinstance(gate, ops.SwapPowGate): + return gate.exponent == 1 + + if isinstance(gate, ops.ISwapPowGate): + return _is_integer((gate.exponent - 1) / 2) + + if isinstance(gate, ops.FSimGate): + return _is_integer(gate.theta / np.pi - 1 / 2) + + return False + + +@transformer_api.transformer +def eject_z( + circuit: 'cirq.AbstractCircuit', + *, + context: Optional['cirq.TransformerContext'] = None, + atol: float = 0.0, + eject_parameterized: bool = False, +) -> 'cirq.Circuit': + """Pushes Z gates towards the end of the circuit. + + As the Z gates get pushed they may absorb other Z gates, get absorbed into + measurements, cross CZ gates, cross PhasedXPowGate (aka W) gates (by phasing them), etc. + + Args: + circuit: Input circuit to transform. + context: `cirq.TransformerContext` storing common configurable options for transformers. + atol: Maximum absolute error tolerance. The optimization is + permitted to simply drop negligible combinations of Z gates, + with a threshold determined by this tolerance. + eject_parameterized: If True, the optimization will attempt to eject + parameterized Z gates as well. This may result in other gates + parameterized by symbolic expressions. + Returns: + Copy of the transformed input circuit. + """ + # Tracks qubit phases (in half turns; multiply by pi to get radians). + qubit_phase: Dict[ops.Qid, float] = defaultdict(lambda: 0) + tags_to_ignore = set(context.tags_to_ignore) if context else set() + phased_xz_replacements: Dict[Tuple[int, ops.Operation], ops.PhasedXZGate] = {} + last_phased_xz_op: Dict[ops.Qid, Optional[Tuple[int, ops.Operation]]] = defaultdict( + lambda: None + ) + + def dump_tracked_phase(qubits: Iterable[ops.Qid]) -> 'cirq.OP_TREE': + """Zeroes qubit_phase entries by emitting Z gates.""" + for q in qubits: + p, key = qubit_phase[q], last_phased_xz_op[q] + qubit_phase[q] = 0 + if not (key or single_qubit_decompositions.is_negligible_turn(p, atol)): + yield ops.Z(q) ** (p * 2) + elif key: + phased_xz_replacements[key] = phased_xz_replacements[key].with_z_exponent(p * 2) + + def map_func(op: 'cirq.Operation', moment_index: int) -> 'cirq.OP_TREE': + last_phased_xz_op.update({q: None for q in op.qubits}) + + if tags_to_ignore & set(op.tags): + # Op marked with no-compile, dump phases and do not cross. + return [dump_tracked_phase(op.qubits), op] + + gate = op.gate + # Return if circuit operation. + if gate is None: + return op + + # Swap phases if `op` is a swap operation. + if _is_swaplike(gate): + a, b = op.qubits + qubit_phase[a], qubit_phase[b] = qubit_phase[b], qubit_phase[a] + return op + + # Z gate before measurement is a no-op. Drop tracked phase. + if isinstance(gate, ops.MeasurementGate): + for q in op.qubits: + qubit_phase[q] = 0 + return op + + # Move Z gates into tracked qubit phases. + if isinstance(gate, ops.ZPowGate) and ( + eject_parameterized or not protocols.is_parameterized(gate) + ): + qubit_phase[op.qubits[0]] += gate.exponent / 2 + return [] + + # Try to move the tracked phases over the operation via protocols.phase_by(op) + phased_op = op + for i, p in enumerate([qubit_phase[q] for q in op.qubits]): + if not single_qubit_decompositions.is_negligible_turn(p, atol): + phased_op = protocols.phase_by(phased_op, -p, i, default=None) + if phased_op is None: + return [dump_tracked_phase(op.qubits), op] + + gate = phased_op.gate + if isinstance(gate, ops.PhasedXZGate) and ( + eject_parameterized or not protocols.is_parameterized(gate.z_exponent) + ): + qubit = phased_op.qubits[0] + qubit_phase[qubit] += gate.z_exponent / 2 + gate = gate.with_z_exponent(0) + phased_op = gate.on(qubit) + phased_xz_replacements[moment_index, phased_op] = gate + last_phased_xz_op[qubit] = (moment_index, phased_op) + return phased_op + + circuit = transformer_primitives.map_operations(circuit, map_func).unfreeze(copy=False) + circuit.append(dump_tracked_phase(qubit_phase.keys())) + circuit.batch_replace( + (m, op, g.on(*op.qubits)) for (m, op), g in phased_xz_replacements.items() + ) + return transformer_primitives.unroll_circuit_op(circuit) diff --git a/cirq-core/cirq/transformers/eject_z_test.py b/cirq-core/cirq/transformers/eject_z_test.py new file mode 100644 index 00000000000..ca8743a35f1 --- /dev/null +++ b/cirq-core/cirq/transformers/eject_z_test.py @@ -0,0 +1,448 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import numpy as np +import sympy + +import cirq +from cirq.transformers.eject_z import _is_swaplike + + +def assert_optimizes( + before: cirq.Circuit, + expected: cirq.Circuit, + eject_parameterized: bool = False, + *, + with_context: bool = False, +): + if cirq.has_unitary(before): + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + before, expected, atol=1e-8 + ) + context = cirq.TransformerContext(tags_to_ignore=("nocompile",)) if with_context else None + circuit = cirq.eject_z(before, eject_parameterized=eject_parameterized, context=context) + expected = cirq.eject_z(expected, eject_parameterized=eject_parameterized, context=context) + cirq.testing.assert_same_circuits(circuit, expected) + + # And it should be idempotent. + circuit = cirq.eject_z(before, eject_parameterized=eject_parameterized, context=context) + cirq.testing.assert_same_circuits(circuit, expected) + + +def assert_removes_all_z_gates(circuit: cirq.Circuit, eject_parameterized: bool = True): + optimized = cirq.eject_z(circuit, eject_parameterized=eject_parameterized) + for op in optimized.all_operations(): + # assert _try_get_known_z_half_turns(op, eject_parameterized) is None + if isinstance(op.gate, cirq.PhasedXZGate) and ( + eject_parameterized or not cirq.is_parameterized(op.gate.z_exponent) + ): + assert op.gate.z_exponent == 0 + + if cirq.is_parameterized(circuit): + for a in (0, 0.1, 0.5, 1.0, -1.0, 3.0): + ( + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + cirq.resolve_parameters(circuit, {'a': a}), + cirq.resolve_parameters(optimized, {'a': a}), + atol=1e-8, + ) + ) + else: + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + circuit, optimized, atol=1e-8 + ) + + +def test_single_z_stays(): + q = cirq.NamedQubit('q') + assert_optimizes( + before=cirq.Circuit( + [ + cirq.Moment([cirq.Z(q) ** 0.5]), + ] + ), + expected=cirq.Circuit( + [ + cirq.Moment([cirq.Z(q) ** 0.5]), + ] + ), + ) + + +def test_single_phased_xz_stays(): + gate = cirq.PhasedXZGate(axis_phase_exponent=0.2, x_exponent=0.3, z_exponent=0.4) + q = cirq.NamedQubit('q') + assert_optimizes(before=cirq.Circuit(gate(q)), expected=cirq.Circuit(gate(q))) + + +def test_ignores_xz_and_cz(): + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + assert_optimizes( + before=cirq.Circuit( + [ + cirq.Moment([cirq.X(a) ** 0.5]), + cirq.Moment([cirq.Y(b) ** 0.5]), + cirq.Moment([cirq.CZ(a, b) ** 0.25]), + cirq.Moment([cirq.Y(a) ** 0.5]), + cirq.Moment([cirq.X(b) ** 0.5]), + ] + ), + expected=cirq.Circuit( + [ + cirq.Moment([cirq.X(a) ** 0.5]), + cirq.Moment([cirq.Y(b) ** 0.5]), + cirq.Moment([cirq.CZ(a, b) ** 0.25]), + cirq.Moment([cirq.Y(a) ** 0.5]), + cirq.Moment([cirq.X(b) ** 0.5]), + ] + ), + ) + + +def test_early_z(): + q = cirq.NamedQubit('q') + assert_optimizes( + before=cirq.Circuit( + [ + cirq.Moment([cirq.Z(q) ** 0.5]), + cirq.Moment(), + cirq.Moment(), + ] + ), + expected=cirq.Circuit( + [ + cirq.Moment([cirq.Z(q) ** 0.5]), + cirq.Moment(), + cirq.Moment(), + ] + ), + ) + + +def test_multi_z_merges(): + q = cirq.NamedQubit('q') + assert_optimizes( + before=cirq.Circuit( + [ + cirq.Moment([cirq.Z(q) ** 0.5]), + cirq.Moment([cirq.Z(q) ** 0.25]), + ] + ), + expected=cirq.Circuit( + [ + cirq.Moment(), + cirq.Moment([cirq.Z(q) ** 0.75]), + ] + ), + ) + + +def test_z_pushes_past_xy_and_phases_it(): + q = cirq.NamedQubit('q') + assert_optimizes( + before=cirq.Circuit( + [ + cirq.Moment([cirq.Z(q) ** 0.5]), + cirq.Moment([cirq.Y(q) ** 0.25]), + ] + ), + expected=cirq.Circuit( + [ + cirq.Moment(), + cirq.Moment([cirq.X(q) ** 0.25]), + cirq.Moment([cirq.Z(q) ** 0.5]), + ] + ), + ) + + +def test_z_pushes_past_cz(): + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + assert_optimizes( + before=cirq.Circuit( + [ + cirq.Moment([cirq.Z(a) ** 0.5]), + cirq.Moment([cirq.CZ(a, b) ** 0.25]), + ] + ), + expected=cirq.Circuit( + [ + cirq.Moment(), + cirq.Moment([cirq.CZ(a, b) ** 0.25]), + cirq.Moment([cirq.Z(a) ** 0.5]), + ] + ), + ) + + +def test_measurement_consumes_zs(): + q = cirq.NamedQubit('q') + assert_optimizes( + before=cirq.Circuit( + [ + cirq.Moment([cirq.Z(q) ** 0.5]), + cirq.Moment([cirq.Z(q) ** 0.25]), + cirq.Moment([cirq.measure(q)]), + ] + ), + expected=cirq.Circuit( + [ + cirq.Moment(), + cirq.Moment(), + cirq.Moment([cirq.measure(q)]), + ] + ), + ) + + +def test_unphaseable_causes_earlier_merge_without_size_increase(): + class UnknownGate(cirq.SingleQubitGate): + pass + + u = UnknownGate() + + # pylint: disable=not-callable + q = cirq.NamedQubit('q') + assert_optimizes( + before=cirq.Circuit( + [ + cirq.Moment([cirq.Z(q)]), + cirq.Moment([u(q)]), + cirq.Moment([cirq.Z(q) ** 0.5]), + cirq.Moment([cirq.X(q)]), + cirq.Moment([cirq.Z(q) ** 0.25]), + cirq.Moment([cirq.X(q)]), + cirq.Moment([u(q)]), + ] + ), + expected=cirq.Circuit( + [ + cirq.Moment([cirq.Z(q)]), + cirq.Moment([u(q)]), + cirq.Moment(), + cirq.Moment([cirq.PhasedXPowGate(phase_exponent=-0.5)(q)]), + cirq.Moment(), + cirq.Moment([cirq.PhasedXPowGate(phase_exponent=-0.75).on(q)]), + cirq.Moment([cirq.Z(q) ** 0.75]), + cirq.Moment([u(q)]), + ] + ), + ) + + +@pytest.mark.parametrize( + 'sym', + [ + sympy.Symbol('a'), + sympy.Symbol('a') + 1, + ], +) +def test_symbols_block(sym): + q = cirq.NamedQubit('q') + assert_optimizes( + before=cirq.Circuit( + [ + cirq.Moment([cirq.Z(q)]), + cirq.Moment([cirq.Z(q) ** sym]), + cirq.Moment([cirq.Z(q) ** 0.25]), + ] + ), + expected=cirq.Circuit( + [ + cirq.Moment(), + cirq.Moment([cirq.Z(q) ** sym]), + cirq.Moment([cirq.Z(q) ** 1.25]), + ] + ), + ) + + +@pytest.mark.parametrize( + 'sym', + [ + sympy.Symbol('a'), + sympy.Symbol('a') + 1, + ], +) +def test_symbols_eject(sym): + q = cirq.NamedQubit('q') + assert_optimizes( + before=cirq.Circuit( + [ + cirq.Moment([cirq.Z(q)]), + cirq.Moment([cirq.Z(q) ** sym]), + cirq.Moment([cirq.Z(q) ** 0.25]), + ] + ), + expected=cirq.Circuit( + [ + cirq.Moment(), + cirq.Moment(), + cirq.Moment([cirq.Z(q) ** (sym + 1.25)]), + ] + ), + eject_parameterized=True, + ) + + +def test_removes_zs(): + a = cirq.NamedQubit('a') + b = cirq.NamedQubit('b') + + assert_removes_all_z_gates(cirq.Circuit(cirq.Z(a), cirq.measure(a))) + + assert_removes_all_z_gates(cirq.Circuit(cirq.Z(a), cirq.measure(a, b))) + + assert_removes_all_z_gates(cirq.Circuit(cirq.Z(a), cirq.Z(a), cirq.measure(a))) + + assert_removes_all_z_gates(cirq.Circuit(cirq.Z(a), cirq.measure(a, key='k'))) + + assert_removes_all_z_gates(cirq.Circuit(cirq.Z(a), cirq.X(a), cirq.measure(a))) + + assert_removes_all_z_gates(cirq.Circuit(cirq.Z(a), cirq.X(a), cirq.X(a), cirq.measure(a))) + + assert_removes_all_z_gates( + cirq.Circuit(cirq.Z(a), cirq.Z(b), cirq.CZ(a, b), cirq.CZ(a, b), cirq.measure(a, b)) + ) + + assert_removes_all_z_gates( + cirq.Circuit( + cirq.PhasedXZGate(axis_phase_exponent=0, x_exponent=0, z_exponent=1).on(a), + cirq.measure(a), + ) + ) + + assert_removes_all_z_gates( + cirq.Circuit( + cirq.Z(a) ** sympy.Symbol('a'), + cirq.Z(b) ** (sympy.Symbol('a') + 1), + cirq.CZ(a, b), + cirq.CZ(a, b), + cirq.measure(a, b), + ), + eject_parameterized=True, + ) + + +def test_unknown_operation_blocks(): + q = cirq.NamedQubit('q') + + class UnknownOp(cirq.Operation): + @property + def qubits(self): + return [q] + + def with_qubits(self, *new_qubits): + raise NotImplementedError() + + u = UnknownOp() + + assert_optimizes( + before=cirq.Circuit( + [ + cirq.Moment([cirq.Z(q)]), + cirq.Moment([u]), + ] + ), + expected=cirq.Circuit( + [ + cirq.Moment([cirq.Z(q)]), + cirq.Moment([u]), + ] + ), + ) + + +def test_tagged_nocompile_operation_blocks(): + q = cirq.NamedQubit('q') + u = cirq.Z(q).with_tags("nocompile") + assert_optimizes( + before=cirq.Circuit( + [ + cirq.Moment([cirq.Z(q)]), + cirq.Moment([u]), + ] + ), + expected=cirq.Circuit( + [ + cirq.Moment([cirq.Z(q)]), + cirq.Moment([u]), + ] + ), + with_context=True, + ) + + +def test_swap(): + a, b = cirq.LineQubit.range(2) + original = cirq.Circuit([cirq.rz(0.123).on(a), cirq.SWAP(a, b)]) + optimized = original.copy() + + optimized = cirq.eject_z(optimized) + optimized = cirq.drop_empty_moments(optimized) + + assert optimized[0].operations == (cirq.SWAP(a, b),) + # Note: EjectZ drops `global_phase` from Rz turning it into a Z + assert optimized[1].operations == (cirq.Z(b) ** (0.123 / np.pi),) + cirq.testing.assert_allclose_up_to_global_phase( + cirq.unitary(original), cirq.unitary(optimized), atol=1e-8 + ) + + +@pytest.mark.parametrize('exponent', (0, 2, 1.1, -2, -1.6)) +def test_not_a_swap(exponent): + a, b = cirq.LineQubit.range(2) + assert not _is_swaplike(cirq.SWAP(a, b) ** exponent) + + +@pytest.mark.parametrize('theta', (np.pi / 2, -np.pi / 2, np.pi / 2 + 5 * np.pi)) +def test_swap_fsim(theta): + a, b = cirq.LineQubit.range(2) + original = cirq.Circuit([cirq.rz(0.123).on(a), cirq.FSimGate(theta=theta, phi=0.123).on(a, b)]) + optimized = original.copy() + + optimized = cirq.eject_z(optimized) + optimized = cirq.drop_empty_moments(optimized) + + assert optimized[0].operations == (cirq.FSimGate(theta=theta, phi=0.123).on(a, b),) + # Note: EjectZ drops `global_phase` from Rz turning it into a Z + assert optimized[1].operations == (cirq.Z(b) ** (0.123 / np.pi),) + cirq.testing.assert_allclose_up_to_global_phase( + cirq.unitary(original), cirq.unitary(optimized), atol=1e-8 + ) + + +@pytest.mark.parametrize('theta', (0, 5 * np.pi, -np.pi)) +def test_not_a_swap_fsim(theta): + a, b = cirq.LineQubit.range(2) + assert not _is_swaplike(cirq.FSimGate(theta=theta, phi=0.456).on(a, b)) + + +@pytest.mark.parametrize('exponent', (1, -1)) +def test_swap_iswap(exponent): + a, b = cirq.LineQubit.range(2) + original = cirq.Circuit([cirq.rz(0.123).on(a), cirq.ISWAP(a, b) ** exponent]) + optimized = original.copy() + + optimized = cirq.eject_z(optimized) + optimized = cirq.drop_empty_moments(optimized) + + assert optimized[0].operations == (cirq.ISWAP(a, b) ** exponent,) + # Note: EjectZ drops `global_phase` from Rz turning it into a Z + assert optimized[1].operations == (cirq.Z(b) ** (0.123 / np.pi),) + cirq.testing.assert_allclose_up_to_global_phase( + cirq.unitary(original), cirq.unitary(optimized), atol=1e-8 + ) diff --git a/cirq-google/cirq_google/optimizers/optimize_for_sycamore.py b/cirq-google/cirq_google/optimizers/optimize_for_sycamore.py index 58d6bce64c8..efcca7f00a5 100644 --- a/cirq-google/cirq_google/optimizers/optimize_for_sycamore.py +++ b/cirq-google/cirq_google/optimizers/optimize_for_sycamore.py @@ -32,7 +32,6 @@ def _get_common_cleanup_optimizers(tolerance: float) -> List[Callable[[cirq.Circuit], None]]: return [ cirq.EjectPhasedPaulis(tolerance=tolerance).optimize_circuit, - cirq.EjectZ(tolerance=tolerance).optimize_circuit, ] @@ -166,6 +165,7 @@ def optimized_for_sycamore( for optimizer in opts: optimizer(copy) + copy = cirq.eject_z(copy, atol=tolerance) copy = cirq.drop_negligible_operations(copy, atol=tolerance) ret = cirq.Circuit( diff --git a/docs/google/best_practices.md b/docs/google/best_practices.md index 937634c224b..0794593dd98 100644 --- a/docs/google/best_practices.md +++ b/docs/google/best_practices.md @@ -296,7 +296,7 @@ book-keeping measures rather than physical operations on the device. These virtual Z operations have zero duration and have no cost, if they add no moments to your circuit. In order to guarantee that they do not add moments, you can make sure that virtual Z are aggregated into their own layer. Alternatively, -you can use the `EjectZ` optimizer to propagate these Z gates forward through +you can use the `cirq.eject_z` optimizer to propagate these Z gates forward through commuting operators. See the function `cirq.stratified_circuit` for an automated way to organize gates diff --git a/docs/tutorials/google/spin_echoes.ipynb b/docs/tutorials/google/spin_echoes.ipynb index 969f3f65e1a..e02dcf4c35c 100644 --- a/docs/tutorials/google/spin_echoes.ipynb +++ b/docs/tutorials/google/spin_echoes.ipynb @@ -231,7 +231,7 @@ " if with_optimization:\n", " cirq.MergeInteractionsToSqrtIswap().optimize_circuit(circuit)\n", " cirq.EjectPhasedPaulis().optimize_circuit(circuit)\n", - " cirq.EjectZ().optimize_circuit(circuit)\n", + " circuit = cirq.eject_z(circuit)\n", " circuit = cirq.drop_negligible_operations(circuit)\n", " circuit = cirq.drop_empty_moments(circuit)\n", "\n", @@ -664,7 +664,7 @@ "id": "oKcnKYTE-ms3" }, "source": [ - "You can also use the `cirq.EjectZ` optimizer to attempt to push `cirq.Z` gates towards the end of the circuit." + "You can also use the `cirq.eject_z` optimizer to attempt to push `cirq.Z` gates towards the end of the circuit." ] }, { @@ -713,7 +713,7 @@ } ], "source": [ - "cirq.EjectZ().optimize_circuit(circuit)\n", + "circuit = cirq.eject_z(circuit)\n", "circuit" ] },