From 8301712810e86fd5795f822c10481461eb17a546 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Mon, 18 Jul 2022 15:23:43 -0700 Subject: [PATCH] Support multi-qubit measurements in deferred measurement transformer (#5787) * Support multi-qubit measurements in deferred measurement transformer * mypy * invert if branch * docstring --- .../transformers/measurement_transformers.py | 26 +++++++-------- .../measurement_transformers_test.py | 32 +++++++++++++++---- 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/cirq-core/cirq/transformers/measurement_transformers.py b/cirq-core/cirq/transformers/measurement_transformers.py index 24583e5bf1d..2e101bef4a4 100644 --- a/cirq-core/cirq/transformers/measurement_transformers.py +++ b/cirq-core/cirq/transformers/measurement_transformers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union from cirq import ops, protocols, value @@ -81,9 +82,7 @@ def defer_measurements( A circuit with equivalent logic, but all measurements at the end of the circuit. Raises: - ValueError: If sympy-based classical conditions are used, or if - conditions based on multi-qubit measurements exist. (The latter of - these is planned to be implemented soon). + ValueError: If sympy-based classical conditions are used. NotImplementedError: When attempting to defer a measurement with a confusion map. (https://github.com/quantumlib/Cirq/issues/5482) """ @@ -111,23 +110,22 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': elif protocols.is_measurement(op): return [defer(op, None) for op in protocols.decompose_once(op)] elif op.classical_controls: - controls = [] + new_op = op.without_classical_controls() for c in op.classical_controls: if isinstance(c, value.KeyCondition): if c.key not in measurement_qubits: raise ValueError(f'Deferred measurement for key={c.key} not found.') - qubits = measurement_qubits[c.key] - if len(qubits) != 1: - # TODO: Multi-qubit conditions require - # https://github.com/quantumlib/Cirq/issues/4512 - # Remember to update docstring above once this works. - raise ValueError('Only single qubit conditions are allowed.') - controls.extend(qubits) + qs = measurement_qubits[c.key] + if len(qs) == 1: + control_values: Any = range(1, qs[0].dimension) + else: + all_values = itertools.product(*[range(q.dimension) for q in qs]) + anything_but_all_zeros = tuple(itertools.islice(all_values, 1, None)) + control_values = ops.SumOfProducts(anything_but_all_zeros) + new_op = new_op.controlled_by(*qs, control_values=control_values) else: raise ValueError('Only KeyConditions are allowed.') - return op.without_classical_controls().controlled_by( - *controls, control_values=[tuple(range(1, q.dimension)) for q in controls] - ) + return new_op return op circuit = transformer_primitives.map_operations_and_unroll( diff --git a/cirq-core/cirq/transformers/measurement_transformers_test.py b/cirq-core/cirq/transformers/measurement_transformers_test.py index cd1041f2860..1308614a82d 100644 --- a/cirq-core/cirq/transformers/measurement_transformers_test.py +++ b/cirq-core/cirq/transformers/measurement_transformers_test.py @@ -225,6 +225,31 @@ def test_multi_qubit_measurements(): ) +def test_multi_qubit_control(): + q0, q1, q2 = cirq.LineQubit.range(3) + circuit = cirq.Circuit( + cirq.measure(q0, q1, key='a'), + cirq.X(q2).with_classical_controls('a'), + cirq.measure(q2, key='b'), + ) + assert_equivalent_to_deferred(circuit) + deferred = cirq.defer_measurements(circuit) + q_ma0 = _MeasurementQid('a', q0) + q_ma1 = _MeasurementQid('a', q1) + cirq.testing.assert_same_circuits( + deferred, + cirq.Circuit( + cirq.CX(q0, q_ma0), + cirq.CX(q1, q_ma1), + cirq.X(q2).controlled_by( + q_ma0, q_ma1, control_values=cirq.SumOfProducts(((0, 1), (1, 0), (1, 1))) + ), + cirq.measure(q_ma0, q_ma1, key='a'), + cirq.measure(q2, key='b'), + ), + ) + + def test_diagram(): q0, q1, q2, q3 = cirq.LineQubit.range(4) circuit = cirq.Circuit( @@ -270,13 +295,6 @@ def test_repr(qid: _MeasurementQid): test_repr(_MeasurementQid('0:1:a', cirq.LineQid(9, 4))) -def test_multi_qubit_control(): - q0, q1 = cirq.LineQubit.range(2) - circuit = cirq.Circuit(cirq.measure(q0, q1, key='a'), cirq.X(q1).with_classical_controls('a')) - with pytest.raises(ValueError, match='Only single qubit conditions are allowed'): - _ = cirq.defer_measurements(circuit) - - def test_sympy_control(): q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit(