From 6e64449dccb8cb2613f5f2aa6df9e32955941414 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Tue, 11 Oct 2022 20:34:15 -0700 Subject: [PATCH] Handle confusion matrices in deferred measurements (#5851) Handles measurement confusion matrices by substituting with a confusion channel. These channels are derived as described in the `_ConfusionChannel` docstring, and evolve the diagonal of the density matrix in the same way that the confusion matrix operates on the classical state distribution. I hadn't really thought of it before, but when writing the test I realized this is _great_ for sampling circuits with classical controls. Way better performance than having to simulate that same number of iterations due to classical logic. --- cirq-core/cirq/ops/gate_operation_test.py | 1 + .../transformers/measurement_transformers.py | 137 +++++++++++++++- .../measurement_transformers_test.py | 146 +++++++++++++++++- 3 files changed, 272 insertions(+), 12 deletions(-) diff --git a/cirq-core/cirq/ops/gate_operation_test.py b/cirq-core/cirq/ops/gate_operation_test.py index 2902556af13..e37d62df183 100644 --- a/cirq-core/cirq/ops/gate_operation_test.py +++ b/cirq-core/cirq/ops/gate_operation_test.py @@ -494,6 +494,7 @@ def all_subclasses(cls): cirq.Pauli, # Private gates. cirq.transformers.analytical_decompositions.two_qubit_to_fsim._BGate, + cirq.transformers.measurement_transformers._ConfusionChannel, cirq.transformers.measurement_transformers._ModAdd, cirq.transformers.routing.visualize_routed_circuit._SwapPrintGate, cirq.ops.raw_types._InverseCompositeGate, diff --git a/cirq-core/cirq/transformers/measurement_transformers.py b/cirq-core/cirq/transformers/measurement_transformers.py index c931f0e24b5..506ce9ad864 100644 --- a/cirq-core/cirq/transformers/measurement_transformers.py +++ b/cirq-core/cirq/transformers/measurement_transformers.py @@ -13,9 +13,11 @@ # limitations under the License. import itertools -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union -from cirq import ops, protocols, value +import numpy as np + +from cirq import linalg, ops, protocols, value from cirq.transformers import transformer_api, transformer_primitives from cirq.transformers.synchronize_terminal_measurements import find_terminal_measurements @@ -96,17 +98,19 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': return op gate = op.gate if isinstance(gate, ops.MeasurementGate): - if gate.confusion_map: - raise NotImplementedError( - "Deferring confused measurement is not implemented, but found " - f"measurement with key={gate.key} and non-empty confusion map." - ) key = value.MeasurementKey.parse_serialized(gate.key) targets = [_MeasurementQid(key, q) for q in op.qubits] measurement_qubits[key] = targets cxs = [_mod_add(q, target) for q, target in zip(op.qubits, targets)] + confusions = [ + _ConfusionChannel(m, [op.qubits[i].dimension for i in indexes]).on( + *[targets[i] for i in indexes] + ) + for indexes, m in gate.confusion_map.items() + ] + cxs = [_mod_add(q, target) for q, target in zip(op.qubits, targets)] xs = [ops.X(targets[i]) for i, b in enumerate(gate.full_invert_mask()) if b] - return cxs + xs + return cxs + confusions + xs elif protocols.is_measurement(op): return [defer(op, None) for op in protocols.decompose_once(op)] elif op.classical_controls: @@ -229,6 +233,123 @@ def flip_inversion(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': ).unfreeze() +class _ConfusionChannel(ops.Gate): + r"""The quantum equivalent of a confusion matrix. + + This gate performs a complete dephasing of the input qubits, and then confuses the remaining + diagonal components per the input confusion matrix. + + For a classical confusion matrix, the quantum equivalent is a channel that can be calculated + by transposing the matrix, taking the square root of each term, and forming a Kraus sequence + of each term individually and the rest zeroed out. For example, consider the confusion matrix + + $$ + \begin{aligned} + M_C =& \begin{bmatrix} + 0.8 & 0.2 \\ + 0.1 & 0.9 + \end{bmatrix} + \end{aligned} + $$ + + If $a$ and $b (= 1-a)$ are probabilities of two possible classical states for a measurement, + the confusion matrix operates on those probabilities as + + $$ + (a, b) M_C = (0.8a + 0.1b, 0.2a + 0.9b) + $$ + + This is equivalent to the following Kraus representation operating on a diagonal of a density + matrix: + + $$ + \begin{aligned} + M_0 =& \begin{bmatrix} + \sqrt{0.8} & 0 \\ + 0 & 0 + \end{bmatrix} + \\ + M_1 =& \begin{bmatrix} + 0 & \sqrt{0.1} \\ + 0 & 0 + \end{bmatrix} + \\ + M_2 =& \begin{bmatrix} + 0 & 0 \\ + \sqrt{0.2} & 0 + \end{bmatrix} + \\ + M_3 =& \begin{bmatrix} + 0 & 0 \\ + 0 & \sqrt{0.9} + \end{bmatrix} + \end{aligned} + \\ + $$ + Then for + $$ + \begin{aligned} + \rho =& \begin{bmatrix} + a & ? \\ + ? & b + \end{bmatrix} + \end{aligned} + \\ + \\ + $$ + the evolution of + $$ + \rho \rightarrow M_0 \rho M_0^\dagger + + M_1 \rho M_1^\dagger + + M_2 \rho M_2^\dagger + + M_3 \rho M_3^\dagger + $$ + gives the result + $$ + \begin{aligned} + \rho =& \begin{bmatrix} + 0.8a + 0.1b & 0 \\ + 0 & 0.2a + 0.9b + \end{bmatrix} + \end{aligned} + \\ + $$ + + Thus in a deferred measurement scenario, applying this channel to the ancilla qubit will model + the noise distribution that would have been caused by the confusion matrix. The math + generalizes cleanly to n-dimensional measurements as well. + """ + + def __init__(self, confusion_map: np.ndarray, shape: Sequence[int]): + if confusion_map.ndim != 2: + raise ValueError('Confusion map must be 2D.') + row_count, col_count = confusion_map.shape + if row_count != col_count: + raise ValueError('Confusion map must be square.') + if row_count != np.prod(shape): + raise ValueError('Confusion map size does not match qubit shape.') + kraus = [] + for r in range(row_count): + for c in range(col_count): + v = confusion_map[r, c] + if v < 0: + raise ValueError('Confusion map has negative probabilities.') + if v > 0: + m = np.zeros(confusion_map.shape) + m[c, r] = np.sqrt(v) + kraus.append(m) + if not linalg.is_cptp(kraus_ops=kraus): + raise ValueError('Confusion map has invalid probabilities.') + self._shape = tuple(shape) + self._kraus = tuple(kraus) + + def _qid_shape_(self) -> Tuple[int, ...]: + return self._shape + + def _kraus_(self) -> Tuple[np.ndarray, ...]: + return self._kraus + + @value.value_equality class _ModAdd(ops.ArithmeticGate): """Adds two qudits of the same dimension. diff --git a/cirq-core/cirq/transformers/measurement_transformers_test.py b/cirq-core/cirq/transformers/measurement_transformers_test.py index 31e8ac995f8..c95825a43d7 100644 --- a/cirq-core/cirq/transformers/measurement_transformers_test.py +++ b/cirq-core/cirq/transformers/measurement_transformers_test.py @@ -328,12 +328,150 @@ def test_sympy_control(): def test_confusion_map(): q0, q1 = cirq.LineQubit.range(2) circuit = cirq.Circuit( - cirq.measure(q0, q1, key='a', confusion_map={(0,): np.array([[0.9, 0.1], [0.1, 0.9]])}), + cirq.H(q0), + cirq.measure(q0, key='a', confusion_map={(0,): np.array([[0.8, 0.2], [0.1, 0.9]])}), + cirq.X(q1).with_classical_controls('a'), + cirq.measure(q1, key='b'), + ) + deferred = cirq.defer_measurements(circuit) + + # We use DM simulator because the deferred circuit has channels + sim = cirq.DensityMatrixSimulator() + + # 10K samples would take a long time if we had not deferred the measurements, as we'd have to + # run 10K simulations. Here with DM simulator it's 100ms. + result = sim.sample(deferred, repetitions=10_000) + + # This should be 5_000 due to the H, then 1_000 more due to 0's flipping to 1's with p=0.2, and + # then 500 less due to 1's flipping to 0's with p=0.1, so 5_500. + assert 5_100 <= np.sum(result['a']) <= 5_900 + assert np.all(result['a'] == result['b']) + + +def test_confusion_map_density_matrix(): + q0, q1 = cirq.LineQubit.range(2) + p_q0 = 0.3 # probability to measure 1 for q0 + confusion = np.array([[0.8, 0.2], [0.1, 0.9]]) + circuit = cirq.Circuit( + # Rotate q0 such that the probability to measure 1 is p_q0 + cirq.X(q0) ** (np.arcsin(np.sqrt(p_q0)) * 2 / np.pi), + cirq.measure(q0, key='a', confusion_map={(0,): confusion}), + cirq.X(q1).with_classical_controls('a'), + ) + deferred = cirq.defer_measurements(circuit) + q_order = (q0, q1, _MeasurementQid('a', q0)) + rho = cirq.final_density_matrix(deferred, qubit_order=q_order).reshape((2,) * 6) + + # q0 density matrix should be a diagonal with the probabilities [1-p, p]. + q0_probs = [1 - p_q0, p_q0] + assert np.allclose(cirq.partial_trace(rho, [0]), np.diag(q0_probs)) + + # q1 and the ancilla should both be the q1 probs matmul the confusion matrix. + expected = np.diag(q0_probs @ confusion) + assert np.allclose(cirq.partial_trace(rho, [1]), expected) + assert np.allclose(cirq.partial_trace(rho, [2]), expected) + + +def test_confusion_map_invert_mask_ordering(): + q0 = cirq.LineQubit(0) + # Confusion map sets the measurement to zero, and the invert mask changes it to one. + # If these are run out of order then the result would be zero. + circuit = cirq.Circuit( + cirq.measure( + q0, key='a', confusion_map={(0,): np.array([[1, 0], [1, 0]])}, invert_mask=(1,) + ), + cirq.I(q0), + ) + assert_equivalent_to_deferred(circuit) + + +def test_confusion_map_qudits(): + q0 = cirq.LineQid(0, dimension=3) + # First op takes q0 to superposed state, then confusion map measures 2 regardless. + circuit = cirq.Circuit( + cirq.XPowGate(dimension=3).on(q0) ** 1.3, + cirq.measure( + q0, key='a', confusion_map={(0,): np.array([[0, 0, 1], [0, 0, 1], [0, 0, 1]])} + ), + cirq.IdentityGate(qid_shape=(3,)).on(q0), + ) + assert_equivalent_to_deferred(circuit) + + +def test_multi_qubit_confusion_map(): + q0, q1, q2 = cirq.LineQubit.range(3) + circuit = cirq.Circuit( + cirq.measure( + q0, + q1, + key='a', + confusion_map={ + (0, 1): np.array( + [ + [0.7, 0.1, 0.1, 0.1], + [0.1, 0.6, 0.1, 0.2], + [0.2, 0.2, 0.5, 0.1], + [0.0, 0.0, 1.0, 0.0], + ] + ) + }, + ), + cirq.X(q2).with_classical_controls('a'), + cirq.measure(q2, key='b'), + ) + deferred = cirq.defer_measurements(circuit) + sim = cirq.DensityMatrixSimulator() + result = sim.sample(deferred, repetitions=10_000) + + # The initial state is zero, so the first measurement will confuse by the first line in the + # map, giving 7000 0's, 1000 1's, 1000 2's, and 1000 3's, for a sum of 6000 on average. + assert 5_600 <= np.sum(result['a']) <= 6_400 + + # The measurement will be non-zero 3000 times on average. + assert 2_600 <= np.sum(result['b']) <= 3_400 + + # Try a deterministic one: initial state is 3, which the confusion map sends to 2 with p=1. + deferred.insert(0, cirq.X.on_each(q0, q1)) + result = sim.sample(deferred, repetitions=100) + assert np.sum(result['a']) == 200 + assert np.sum(result['b']) == 100 + + +def test_confusion_map_errors(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a', confusion_map={(0,): np.array([1])}), + cirq.X(q1).with_classical_controls('a'), + ) + with pytest.raises(ValueError, match='map must be 2D'): + _ = cirq.defer_measurements(circuit) + circuit = cirq.Circuit( + cirq.measure(q0, key='a', confusion_map={(0,): np.array([[0.7, 0.3]])}), + cirq.X(q1).with_classical_controls('a'), + ) + with pytest.raises(ValueError, match='map must be square'): + _ = cirq.defer_measurements(circuit) + circuit = cirq.Circuit( + cirq.measure( + q0, + key='a', + confusion_map={(0,): np.array([[0.7, 0.1, 0.2], [0.1, 0.6, 0.3], [0.2, 0.2, 0.6]])}, + ), + cirq.X(q1).with_classical_controls('a'), + ) + with pytest.raises(ValueError, match='size does not match'): + _ = cirq.defer_measurements(circuit) + circuit = cirq.Circuit( + cirq.measure(q0, key='a', confusion_map={(0,): np.array([[-1, 2], [0, 1]])}), + cirq.X(q1).with_classical_controls('a'), + ) + with pytest.raises(ValueError, match='negative probabilities'): + _ = cirq.defer_measurements(circuit) + circuit = cirq.Circuit( + cirq.measure(q0, key='a', confusion_map={(0,): np.array([[0.3, 0.3], [0.3, 0.3]])}), cirq.X(q1).with_classical_controls('a'), ) - with pytest.raises( - NotImplementedError, match='Deferring confused measurement is not implemented' - ): + with pytest.raises(ValueError, match='invalid probabilities'): _ = cirq.defer_measurements(circuit)