From af1267ddd854bd3c9f0a3579e376fb328afd9811 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Mon, 19 Dec 2022 11:54:13 -0800 Subject: [PATCH] Allow repeated measurements in deferred transformer (#5857) * Add handling for sympy conditions in deferred measurement transformer * docstring * mypy * mypy * cover * Make this more generic, covers all kinds of conditions. * Better docs * Sympy can also be CX * docs * docs * Allow repeated measurements in deferred transformer * Coverage * Add mixed tests, simplify loop, add simplification in ControlledGate * Fix error message * Simplify error message * Inline variable * fix merge * qudit sympy test * fix build * Fix test * Fix test * nits * mypy * mypy * mypy * Add some code comments * Add test for repeated measurement diagram * change test back Co-authored-by: Tanuj Khattar --- .../transformers/measurement_transformers.py | 89 +++++++++++-------- .../measurement_transformers_test.py | 68 ++++++++++---- 2 files changed, 105 insertions(+), 52 deletions(-) diff --git a/cirq-core/cirq/transformers/measurement_transformers.py b/cirq-core/cirq/transformers/measurement_transformers.py index 4df56a452dd..a7da2804dd3 100644 --- a/cirq-core/cirq/transformers/measurement_transformers.py +++ b/cirq-core/cirq/transformers/measurement_transformers.py @@ -13,18 +13,8 @@ # limitations under the License. import itertools -from typing import ( - Any, - Dict, - Iterable, - List, - Mapping, - Optional, - Sequence, - Tuple, - TYPE_CHECKING, - Union, -) +from collections import defaultdict +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union import numpy as np @@ -43,7 +33,7 @@ class _MeasurementQid(ops.Qid): Exactly one qubit will be created per qubit in the measurement gate. """ - def __init__(self, key: Union[str, 'cirq.MeasurementKey'], qid: 'cirq.Qid'): + def __init__(self, key: Union[str, 'cirq.MeasurementKey'], qid: 'cirq.Qid', index: int = 0): """Initializes the qubit. Args: @@ -51,22 +41,24 @@ def __init__(self, key: Union[str, 'cirq.MeasurementKey'], qid: 'cirq.Qid'): qid: One qubit that is being measured. Each deferred measurement should create one new _MeasurementQid per qubit being measured by that gate. + index: For repeated measurement keys, this represents the index of that measurement. """ self._key = value.MeasurementKey.parse_serialized(key) if isinstance(key, str) else key self._qid = qid + self._index = index @property def dimension(self) -> int: return self._qid.dimension def _comparison_key(self) -> Any: - return str(self._key), self._qid._comparison_key() + return str(self._key), self._index, self._qid._comparison_key() def __str__(self) -> str: - return f"M('{self._key}', q={self._qid})" + return f"M('{self._key}[{self._index}]', q={self._qid})" def __repr__(self) -> str: - return f'_MeasurementQid({self._key!r}, {self._qid!r})' + return f'_MeasurementQid({self._key!r}, {self._qid!r}, {self._index})' @transformer_api.transformer @@ -102,7 +94,9 @@ def defer_measurements( circuit = transformer_primitives.unroll_circuit_op(circuit, deep=True, tags_to_check=None) terminal_measurements = {op for _, op in find_terminal_measurements(circuit)} - measurement_qubits: Dict['cirq.MeasurementKey', List['_MeasurementQid']] = {} + measurement_qubits: Dict['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]] = defaultdict( + list + ) def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': if op in terminal_measurements: @@ -110,8 +104,8 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': gate = op.gate if isinstance(gate, ops.MeasurementGate): key = value.MeasurementKey.parse_serialized(gate.key) - targets = [_MeasurementQid(key, q) for q in op.qubits] - measurement_qubits[key] = targets + targets = [_MeasurementQid(key, q, len(measurement_qubits[key])) for q in op.qubits] + measurement_qubits[key].append(tuple(targets)) cxs = [_mod_add(q, target) for q, target in zip(op.qubits, targets)] confusions = [ _ConfusionChannel(m, [op.qubits[i].dimension for i in indexes]).on( @@ -125,10 +119,24 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': return [defer(op, None) for op in protocols.decompose_once(op)] elif op.classical_controls: # Convert to a quantum control - keys = sorted(set(key for c in op.classical_controls for key in c.keys)) - for key in keys: + + # First create a sorted set of the indexed keys for this control. + keys = sorted( + set( + indexed_key + for condition in op.classical_controls + for indexed_key in ( + [(condition.key, condition.index)] + if isinstance(condition, value.KeyCondition) + else [(k, -1) for k in condition.keys] + ) + ) + ) + for key, index in keys: if key not in measurement_qubits: raise ValueError(f'Deferred measurement for key={key} not found.') + if index >= len(measurement_qubits[key]) or index < -len(measurement_qubits[key]): + raise ValueError(f'Invalid index for {key}') # Try every possible datastore state (exponential in the number of keys) against the # condition, and the ones that work are the control values for the new op. @@ -140,12 +148,11 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': # Rearrange these into the format expected by SumOfProducts products = [ - [i for key in keys for i in store.records[key][0]] + [val for k, i in keys for val in store.records[k][i]] for store in compatible_datastores ] - control_values = ops.SumOfProducts(products) - qs = [q for key in keys for q in measurement_qubits[key]] + qs = [q for k, i in keys for q in measurement_qubits[k][i]] return op.without_classical_controls().controlled_by(*qs, control_values=control_values) return op @@ -155,14 +162,15 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE': tags_to_ignore=context.tags_to_ignore if context else (), raise_if_add_qubits=False, ).unfreeze() - for k, qubits in measurement_qubits.items(): - circuit.append(ops.measure(*qubits, key=k)) + for k, qubits_list in measurement_qubits.items(): + for qubits in qubits_list: + circuit.append(ops.measure(*qubits, key=k)) return circuit def _all_possible_datastore_states( - keys: Iterable['cirq.MeasurementKey'], - measurement_qubits: Mapping['cirq.MeasurementKey', Iterable['cirq.Qid']], + keys: Iterable[Tuple['cirq.MeasurementKey', int]], + measurement_qubits: Dict['cirq.MeasurementKey', List[Tuple['cirq.Qid', ...]]], ) -> Iterable['cirq.ClassicalDataStoreReader']: """The cartesian product of all possible DataStore states for the given keys.""" # First we get the list of all possible values. So if we have a key mapped to qubits of shape @@ -179,17 +187,28 @@ def _all_possible_datastore_states( # ((1, 1), (0,)), # ((1, 1), (1,)), # ((1, 1), (2,))] - all_values = itertools.product( + all_possible_measurements = itertools.product( *[ - tuple(itertools.product(*[range(q.dimension) for q in measurement_qubits[k]])) - for k in keys + tuple(itertools.product(*[range(q.dimension) for q in measurement_qubits[k][i]])) + for k, i in keys ] ) - # Then we create the ClassicalDataDictionaryStore for each of the above. - for sequences in all_values: - lookup = {k: [sequence] for k, sequence in zip(keys, sequences)} + # Then we create the ClassicalDataDictionaryStore for each of the above. A `measurement_list` + # is a single row of the above example, and can be zipped with `keys`. + for measurement_list in all_possible_measurements: + # Initialize a set of measurement records for this iteration. This will have the same shape + # as `measurement_qubits` but zeros for all measurements. + records = { + key: [(0,) * len(qubits) for qubits in qubits_list] + for key, qubits_list in measurement_qubits.items() + } + # Set the measurement values from the current row of the above, for each key/index we care + # about. + for (k, i), measurement in zip(keys, measurement_list): + records[k][i] = measurement + # Finally yield this sample to the consumer. yield value.ClassicalDataDictionaryStore( - _records=lookup, _measured_qubits={k: [tuple(measurement_qubits[k])] for k in keys} + _records=records, _measured_qubits=measurement_qubits ) diff --git a/cirq-core/cirq/transformers/measurement_transformers_test.py b/cirq-core/cirq/transformers/measurement_transformers_test.py index a5123d4a3c6..1290ce70606 100644 --- a/cirq-core/cirq/transformers/measurement_transformers_test.py +++ b/cirq-core/cirq/transformers/measurement_transformers_test.py @@ -445,6 +445,40 @@ def test_multi_qubit_control(): ) +@pytest.mark.parametrize('index', [-3, -2, -1, 0, 1, 2]) +def test_repeated(index: int): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.measure(q0, key='a'), # The control measurement when `index` is 0 or -2 + cirq.X(q0), + cirq.measure(q0, key='a'), # The control measurement when `index` is 1 or -1 + cirq.X(q1).with_classical_controls(cirq.KeyCondition(cirq.MeasurementKey('a'), index)), + cirq.measure(q1, key='b'), + ) + if index in [-3, 2]: + with pytest.raises(ValueError, match='Invalid index'): + _ = cirq.defer_measurements(circuit) + return + assert_equivalent_to_deferred(circuit) + deferred = cirq.defer_measurements(circuit) + q_ma = _MeasurementQid('a', q0) # The ancilla qubit created for the first `a` measurement + q_ma1 = _MeasurementQid('a', q0, 1) # The ancilla qubit created for the second `a` measurement + # The ancilla used for control should match the measurement used for control above. + q_expected_control = q_ma if index in [0, -2] else q_ma1 + cirq.testing.assert_same_circuits( + deferred, + cirq.Circuit( + cirq.CX(q0, q_ma), + cirq.X(q0), + cirq.CX(q0, q_ma1), + cirq.Moment(cirq.CX(q_expected_control, q1)), + cirq.measure(q_ma, key='a'), + cirq.measure(q_ma1, key='a'), + cirq.measure(q1, key='b'), + ), + ) + + def test_diagram(): q0, q1, q2, q3 = cirq.LineQubit.range(4) circuit = cirq.Circuit( @@ -457,23 +491,23 @@ def test_diagram(): cirq.testing.assert_has_diagram( deferred, """ - ┌────┐ -0: ─────────────────@───────X────────M('c')─── - │ │ -1: ─────────────────┼─@──────────────M──────── - │ │ │ -2: ─────────────────┼@┼──────────────M──────── - │││ │ -3: ─────────────────┼┼┼@─────────────M──────── - ││││ -M('a', q=q(0)): ────X┼┼┼────M('a')──────────── - │││ │ -M('a', q=q(2)): ─────X┼┼────M───────────────── - ││ -M('b', q=q(1)): ──────X┼────M('b')──────────── - │ │ -M('b', q=q(3)): ───────X────M───────────────── - └────┘ + ┌────┐ +0: ────────────────────@───────X────────M('c')─── + │ │ +1: ────────────────────┼─@──────────────M──────── + │ │ │ +2: ────────────────────┼@┼──────────────M──────── + │││ │ +3: ────────────────────┼┼┼@─────────────M──────── + ││││ +M('a[0]', q=q(0)): ────X┼┼┼────M('a')──────────── + │││ │ +M('a[0]', q=q(2)): ─────X┼┼────M───────────────── + ││ +M('b[0]', q=q(1)): ──────X┼────M('b')──────────── + │ │ +M('b[0]', q=q(3)): ───────X────M───────────────── + └────┘ """, use_unicode_characters=True, )