Skip to content

Commit

Permalink
Add cirq.eject_z transformer to replace cirq.EjectZ (#4955)
Browse files Browse the repository at this point in the history
* Add eject_z transformer to replace EjectZ

* Replaces usages of EjectZ with eject_z

* Reorder cleanup transformers
  • Loading branch information
tanujkhattar authored Feb 7, 2022
1 parent 8b64834 commit 41312e8
Show file tree
Hide file tree
Showing 13 changed files with 636 additions and 157 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@
decompose_two_qubit_interaction_into_four_fsim_gates,
drop_empty_moments,
drop_negligible_operations,
eject_z,
expand_composite,
is_negligible_turn,
map_moments,
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ion/ion_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _cleanup_operations(operations: List[ops.Operation]):
circuit = circuits.Circuit(operations)
optimizers.merge_single_qubit_gates.merge_single_qubit_gates_into_phased_x_z(circuit)
optimizers.eject_phased_paulis.EjectPhasedPaulis().optimize_circuit(circuit)
optimizers.eject_z.EjectZ().optimize_circuit(circuit)
circuit = transformers.eject_z(circuit)
circuit = circuits.Circuit(circuit.all_operations(), strategy=circuits.InsertStrategy.EARLIEST)
return list(circuit.all_operations())

Expand Down
125 changes: 9 additions & 116 deletions cirq-core/cirq/optimizers/eject_z.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,13 @@

"""An optimization pass that pushes Z gates later and later in the circuit."""

from typing import cast, Dict, Iterable, List, Optional, Tuple
from collections import defaultdict
import numpy as np
import sympy
from cirq import circuits, transformers

from cirq import circuits, ops, protocols
from cirq.transformers.analytical_decompositions import single_qubit_decompositions


def _is_integer(n):
return np.isclose(n, np.round(n))


def _is_swaplike(op: ops.Operation):
if isinstance(op.gate, ops.SwapPowGate):
return op.gate.exponent == 1

if isinstance(op.gate, ops.ISwapPowGate):
return _is_integer((op.gate.exponent - 1) / 2)

if isinstance(op.gate, ops.FSimGate):
return _is_integer(op.gate.theta / np.pi - 1 / 2)

return False
# from cirq.transformers import eject_z
from cirq._compat import deprecated_class


@deprecated_class(deadline='v1.0', fix='Use cirq.eject_z instead.')
class EjectZ:
"""Pushes Z gates towards the end of the circuit.
Expand All @@ -62,96 +43,8 @@ def __init__(self, tolerance: float = 0.0, eject_parameterized: bool = False) ->
self.eject_parameterized = eject_parameterized

def optimize_circuit(self, circuit: circuits.Circuit):
# Tracks qubit phases (in half turns; multiply by pi to get radians).
qubit_phase: Dict[ops.Qid, float] = defaultdict(lambda: 0)
deletions: List[Tuple[int, ops.Operation]] = []
replacements: List[Tuple[int, ops.Operation, ops.Operation]] = []
insertions: List[Tuple[int, ops.Operation]] = []
phased_xz_replacements: Dict[Tuple[int, ops.Qid], int] = {}

def dump_tracked_phase(qubits: Iterable[ops.Qid], index: int) -> None:
"""Zeroes qubit_phase entries by emitting Z gates."""
for q in qubits:
p = qubit_phase[q]
qubit_phase[q] = 0
if single_qubit_decompositions.is_negligible_turn(p, self.tolerance):
continue
dumped = False
moment_index = circuit.prev_moment_operating_on([q], index)
if moment_index is not None:
op = circuit.moments[moment_index][q]
if op and isinstance(op.gate, ops.PhasedXZGate):
# Attach z-rotation to replacing PhasedXZ gate.
idx = phased_xz_replacements[moment_index, q]
_, _, repl_op = replacements[idx]
gate = cast(ops.PhasedXZGate, repl_op.gate)
repl_op = gate.with_z_exponent(p * 2).on(q)
replacements[idx] = (moment_index, op, repl_op)
dumped = True
if not dumped:
# Add a new Z gate
dump_op = ops.Z(q) ** (p * 2)
insertions.append((index, dump_op))

for moment_index, moment in enumerate(circuit):
for op in moment.operations:
# Move Z gates into tracked qubit phases.
h = _try_get_known_z_half_turns(op, self.eject_parameterized)
if h is not None:
q = op.qubits[0]
qubit_phase[q] += h / 2
deletions.append((moment_index, op))
continue

# Z gate before measurement is a no-op. Drop tracked phase.
if isinstance(op.gate, ops.MeasurementGate):
for q in op.qubits:
qubit_phase[q] = 0

# If there's no tracked phase, we can move on.
phases = [qubit_phase[q] for q in op.qubits]
if not isinstance(op.gate, ops.PhasedXZGate) and all(
single_qubit_decompositions.is_negligible_turn(p, self.tolerance)
for p in phases
):
continue

if _is_swaplike(op):
a, b = op.qubits
qubit_phase[a], qubit_phase[b] = qubit_phase[b], qubit_phase[a]
continue

# Try to move the tracked phasing over the operation.
phased_op = op
for i, p in enumerate(phases):
if not single_qubit_decompositions.is_negligible_turn(p, self.tolerance):
phased_op = protocols.phase_by(phased_op, -p, i, default=None)
if phased_op is not None:
gate = phased_op.gate
if isinstance(gate, ops.PhasedXZGate) and (
self.eject_parameterized or not protocols.is_parameterized(gate.z_exponent)
):
qubit = phased_op.qubits[0]
qubit_phase[qubit] += gate.z_exponent / 2
phased_op = gate.with_z_exponent(0).on(qubit)
repl_idx = len(replacements)
phased_xz_replacements[moment_index, qubit] = repl_idx
replacements.append((moment_index, op, phased_op))
else:
dump_tracked_phase(op.qubits, moment_index)

dump_tracked_phase(qubit_phase.keys(), len(circuit))
circuit.batch_remove(deletions)
circuit.batch_replace(replacements)
circuit.batch_insert(insertions)


def _try_get_known_z_half_turns(op: ops.Operation, eject_parameterized: bool) -> Optional[float]:
if not isinstance(op, ops.GateOperation):
return None
if not isinstance(op.gate, ops.ZPowGate):
return None
h = op.gate.exponent
if not eject_parameterized and isinstance(h, sympy.Basic):
return None
return h
circuit._moments = [
*transformers.eject_z(
circuit, atol=self.tolerance, eject_parameterized=self.eject_parameterized
)
]
52 changes: 21 additions & 31 deletions cirq-core/cirq/optimizers/eject_z_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,36 @@
import sympy

import cirq
from cirq.optimizers.eject_z import _try_get_known_z_half_turns


def assert_optimizes(
before: cirq.Circuit, expected: cirq.Circuit, eject_parameterized: bool = False
):
opt = cirq.EjectZ(eject_parameterized=eject_parameterized)
with cirq.testing.assert_deprecated("Use cirq.eject_z", deadline='v1.0'):
opt = cirq.EjectZ(eject_parameterized=eject_parameterized)

if cirq.has_unitary(before):
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
before, expected, atol=1e-8
)
if cirq.has_unitary(before):
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
before, expected, atol=1e-8
)

circuit = before.copy()
opt.optimize_circuit(circuit)
opt.optimize_circuit(expected)
circuit = before.copy()
opt.optimize_circuit(circuit)
opt.optimize_circuit(expected)

cirq.testing.assert_same_circuits(circuit, expected)
cirq.testing.assert_same_circuits(circuit, expected)

# And it should be idempotent.
opt.optimize_circuit(circuit)
cirq.testing.assert_same_circuits(circuit, expected)
# And it should be idempotent.
opt.optimize_circuit(circuit)
cirq.testing.assert_same_circuits(circuit, expected)


def assert_removes_all_z_gates(circuit: cirq.Circuit, eject_parameterized: bool = True):
opt = cirq.EjectZ(eject_parameterized=eject_parameterized)
with cirq.testing.assert_deprecated("Use cirq.eject_z", deadline='v1.0'):
opt = cirq.EjectZ(eject_parameterized=eject_parameterized)
optimized = circuit.copy()
opt.optimize_circuit(optimized)
for op in optimized.all_operations():
assert _try_get_known_z_half_turns(op, eject_parameterized) is None
if isinstance(op.gate, cirq.PhasedXZGate) and (
eject_parameterized or not cirq.is_parameterized(op.gate.z_exponent)
):
Expand Down Expand Up @@ -373,7 +373,8 @@ def test_swap():
original = cirq.Circuit([cirq.rz(0.123).on(a), cirq.SWAP(a, b)])
optimized = original.copy()

cirq.EjectZ().optimize_circuit(optimized)
with cirq.testing.assert_deprecated("Use cirq.eject_z", deadline='v1.0'):
cirq.EjectZ().optimize_circuit(optimized)
optimized = cirq.drop_empty_moments(optimized)

assert optimized[0].operations == (cirq.SWAP(a, b),)
Expand All @@ -384,19 +385,14 @@ def test_swap():
)


@pytest.mark.parametrize('exponent', (0, 2, 1.1, -2, -1.6))
def test_not_a_swap(exponent):
a, b = cirq.LineQubit.range(2)
assert not cirq.optimizers.eject_z._is_swaplike(cirq.SWAP(a, b) ** exponent)


@pytest.mark.parametrize('theta', (np.pi / 2, -np.pi / 2, np.pi / 2 + 5 * np.pi))
def test_swap_fsim(theta):
a, b = cirq.LineQubit.range(2)
original = cirq.Circuit([cirq.rz(0.123).on(a), cirq.FSimGate(theta=theta, phi=0.123).on(a, b)])
optimized = original.copy()

cirq.EjectZ().optimize_circuit(optimized)
with cirq.testing.assert_deprecated("Use cirq.eject_z", deadline='v1.0'):
cirq.EjectZ().optimize_circuit(optimized)
optimized = cirq.drop_empty_moments(optimized)

assert optimized[0].operations == (cirq.FSimGate(theta=theta, phi=0.123).on(a, b),)
Expand All @@ -407,19 +403,13 @@ def test_swap_fsim(theta):
)


@pytest.mark.parametrize('theta', (0, 5 * np.pi, -np.pi))
def test_not_a_swap_fsim(theta):
a, b = cirq.LineQubit.range(2)
assert not cirq.optimizers.eject_z._is_swaplike(cirq.FSimGate(theta=theta, phi=0.456).on(a, b))


@pytest.mark.parametrize('exponent', (1, -1))
def test_swap_iswap(exponent):
a, b = cirq.LineQubit.range(2)
original = cirq.Circuit([cirq.rz(0.123).on(a), cirq.ISWAP(a, b) ** exponent])
optimized = original.copy()

cirq.EjectZ().optimize_circuit(optimized)
with cirq.testing.assert_deprecated("Use cirq.eject_z", deadline='v1.0'):
cirq.EjectZ().optimize_circuit(optimized)
optimized = cirq.drop_empty_moments(optimized)

assert optimized[0].operations == (cirq.ISWAP(a, b) ** exponent,)
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/optimizers/merge_interactions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit):
followup_optimizations: List[Callable[[cirq.Circuit], None]] = [
cirq.merge_single_qubit_gates_into_phased_x_z,
cirq.EjectPhasedPaulis().optimize_circuit,
cirq.EjectZ().optimize_circuit,
]
for post in followup_optimizations:
post(actual)
post(expected)

followup_transformers: List[cirq.TRANSFORMER] = [
cirq.eject_z,
cirq.drop_negligible_operations,
cirq.drop_empty_moments,
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ def assert_optimizes(before: cirq.Circuit, expected: cirq.Circuit, **kwargs):
followup_optimizations: List[Callable[[cirq.Circuit], None]] = [
cirq.merge_single_qubit_gates_into_phased_x_z,
cirq.EjectPhasedPaulis().optimize_circuit,
cirq.EjectZ().optimize_circuit,
]
for post in followup_optimizations:
post(actual)
post(expected)

followup_transformers: List[cirq.TRANSFORMER] = [
cirq.eject_z,
cirq.drop_negligible_operations,
cirq.drop_empty_moments,
]
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@

from cirq.transformers.drop_negligible_operations import drop_negligible_operations

from cirq.transformers.eject_z import eject_z

from cirq.transformers.synchronize_terminal_measurements import synchronize_terminal_measurements

from cirq.transformers.transformer_api import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

from cirq import ops, linalg, protocols, circuits
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
from cirq.transformers.eject_z import eject_z
from cirq.optimizers import (
eject_z,
eject_phased_paulis,
merge_single_qubit_gates,
)
Expand Down Expand Up @@ -165,7 +165,7 @@ def _cleanup_operations(operations: Sequence[ops.Operation]):
circuit = circuits.Circuit(operations)
merge_single_qubit_gates.merge_single_qubit_gates_into_phased_x_z(circuit)
eject_phased_paulis.EjectPhasedPaulis().optimize_circuit(circuit)
eject_z.EjectZ().optimize_circuit(circuit)
circuit = eject_z(circuit)
circuit = circuits.Circuit(circuit.all_operations(), strategy=circuits.InsertStrategy.EARLIEST)
return list(circuit.all_operations())

Expand Down
Loading

0 comments on commit 41312e8

Please sign in to comment.