diff --git a/cirq-core/cirq/ops/measurement_gate.py b/cirq-core/cirq/ops/measurement_gate.py index 8071b414ac4..3913cb89e1d 100644 --- a/cirq-core/cirq/ops/measurement_gate.py +++ b/cirq-core/cirq/ops/measurement_gate.py @@ -312,7 +312,19 @@ def _act_on_(self, sim_state: 'cirq.SimulationStateBase', qubits: Sequence['cirq if not isinstance(sim_state, SimulationState): return NotImplemented - sim_state.measure(qubits, self.key, self.full_invert_mask(), self.confusion_map) + try: + sim_state.measure( + qubits, self.key, self.full_invert_mask(), confusion_map=self.confusion_map + ) + except TypeError as e: + # Ensure that the error was due to confusion_map. + if not any("unexpected keyword argument 'confusion_map'" in arg for arg in e.args): + raise + _compat._warn_or_error( + "Starting in v0.16, SimulationState subclasses will be required to accept " + "a 'confusion_map' argument. See SimulationState.measure for details." + ) + sim_state.measure(qubits, self.key, self.full_invert_mask()) return True diff --git a/cirq-core/cirq/ops/measurement_gate_test.py b/cirq-core/cirq/ops/measurement_gate_test.py index ce17956a2f0..22d9239f724 100644 --- a/cirq-core/cirq/ops/measurement_gate_test.py +++ b/cirq-core/cirq/ops/measurement_gate_test.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import cast +from typing import Any, Dict, Optional, Sequence, Tuple, Union, cast import numpy as np import pytest import cirq +from cirq.type_workarounds import NotImplementedType @pytest.mark.parametrize( @@ -529,3 +530,57 @@ def test_act_on_qutrit(): ) cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [0, 0]} + + +def test_act_on_no_confusion_map_deprecated(): + class OldSimState(cirq.StateVectorSimulationState): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.measured = False + + def _act_on_fallback_( + self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True + ) -> Union[bool, NotImplementedType]: + return NotImplemented # coverage: ignore + + def measure( # type: ignore + self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[bool] + ): + self.measured = True + + qubits = cirq.LineQubit.range(2) + old_state = OldSimState(qubits=qubits) + m = cirq.measure(*qubits, key='test') + with cirq.testing.assert_deprecated('confusion_map', deadline='v0.16'): + cirq.act_on(m, old_state) + assert old_state.measured + + +def test_act_on_no_confusion_map_scope_limited(): + error_msg = "error from deeper in measure" + + class ErrorProneSimState(cirq.StateVectorSimulationState): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.measured = False + + def _act_on_fallback_( + self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True + ) -> Union[bool, NotImplementedType]: + return NotImplemented # coverage: ignore + + def measure( + self, + qubits: Sequence['cirq.Qid'], + key: str, + invert_mask: Sequence[bool], + confusion_map: Optional[Dict[Tuple[int, ...], np.ndarray]] = None, + ): + raise TypeError(error_msg) + + # Verify that the check doesn't prevent other errors from being raised + qubits = cirq.LineQubit.range(2) + sv_state = ErrorProneSimState(qubits=qubits) + m = cirq.measure(*qubits, key='test') + with pytest.raises(TypeError, match=error_msg): + cirq.act_on(m, sv_state) diff --git a/cirq-core/cirq/sim/simulation_state.py b/cirq-core/cirq/sim/simulation_state.py index eb8f2e37368..b7fcf884c00 100644 --- a/cirq-core/cirq/sim/simulation_state.py +++ b/cirq-core/cirq/sim/simulation_state.py @@ -105,7 +105,7 @@ def measure( qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[bool], - confusion_map: Dict[Tuple[int, ...], np.ndarray], + confusion_map: Optional[Dict[Tuple[int, ...], np.ndarray]] = None, ): """Measures the qubits and records to `log_of_measurement_results`. @@ -123,7 +123,8 @@ def measure( ValueError: If a measurement key has already been logged to a key. """ bits = self._perform_measurement(qubits) - confused = self._confuse_result(bits, qubits, confusion_map) + if confusion_map is not None: + confused = self._confuse_result(bits, qubits, confusion_map) corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(confused, invert_mask)] self._classical_data.record_measurement( value.MeasurementKey.parse_serialized(key), corrected, qubits