From 30dfe3154d0f876efe405a0fa62a1266714914cf Mon Sep 17 00:00:00 2001 From: Pavol Juhas Date: Mon, 11 Jul 2022 14:22:18 -0700 Subject: [PATCH] unleash confusion map (#5717) Revert "Guard confusion_map usage (#5534)" Assume measure() methods of SimulationState subclasses accept the `confusion_map` argument. Remove temporary workaround for methods not taking `confusion_map`. Also disable pylint check on missing argument which is required in the base class but unknown to subclasses actually called. This reverts commit 25388af7d96edb88ef320a08374e8b9b871126b3. --- cirq-core/cirq/ops/measurement_gate.py | 14 +---- cirq-core/cirq/ops/measurement_gate_test.py | 57 +-------------------- cirq-core/cirq/sim/simulation_state.py | 12 +++-- 3 files changed, 10 insertions(+), 73 deletions(-) diff --git a/cirq-core/cirq/ops/measurement_gate.py b/cirq-core/cirq/ops/measurement_gate.py index cf0c812c0db..40b3b2732f4 100644 --- a/cirq-core/cirq/ops/measurement_gate.py +++ b/cirq-core/cirq/ops/measurement_gate.py @@ -306,19 +306,7 @@ def _act_on_(self, sim_state: 'cirq.SimulationStateBase', qubits: Sequence['cirq if not isinstance(sim_state, SimulationState): return NotImplemented - try: - sim_state.measure( - qubits, self.key, self.full_invert_mask(), confusion_map=self.confusion_map - ) - except TypeError as e: - # Ensure that the error was due to confusion_map. - if not any("unexpected keyword argument 'confusion_map'" in arg for arg in e.args): - raise - _compat._warn_or_error( - "Starting in v0.16, SimulationState subclasses will be required to accept " - "a 'confusion_map' argument. See SimulationState.measure for details." - ) - sim_state.measure(qubits, self.key, self.full_invert_mask()) + sim_state.measure(qubits, self.key, self.full_invert_mask(), self.confusion_map) return True diff --git a/cirq-core/cirq/ops/measurement_gate_test.py b/cirq-core/cirq/ops/measurement_gate_test.py index deb21d3ef76..e45f5e0bb54 100644 --- a/cirq-core/cirq/ops/measurement_gate_test.py +++ b/cirq-core/cirq/ops/measurement_gate_test.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Sequence, Tuple, Union, cast +from typing import cast import numpy as np import pytest import cirq -from cirq.type_workarounds import NotImplementedType @pytest.mark.parametrize( @@ -504,57 +503,3 @@ def test_act_on_qutrit(): ) cirq.act_on(m, args) assert args.log_of_measurement_results == {'out': [0, 0]} - - -def test_act_on_no_confusion_map_deprecated(): - class OldSimState(cirq.StateVectorSimulationState): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.measured = False - - def _act_on_fallback_( - self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True - ) -> Union[bool, NotImplementedType]: - return NotImplemented # coverage: ignore - - def measure( # type: ignore - self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[bool] - ): - self.measured = True - - qubits = cirq.LineQubit.range(2) - old_state = OldSimState(qubits=qubits) - m = cirq.measure(*qubits, key='test') - with cirq.testing.assert_deprecated('confusion_map', deadline='v0.16'): - cirq.act_on(m, old_state) - assert old_state.measured - - -def test_act_on_no_confusion_map_scope_limited(): - error_msg = "error from deeper in measure" - - class ErrorProneSimState(cirq.StateVectorSimulationState): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.measured = False - - def _act_on_fallback_( - self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True - ) -> Union[bool, NotImplementedType]: - return NotImplemented # coverage: ignore - - def measure( - self, - qubits: Sequence['cirq.Qid'], - key: str, - invert_mask: Sequence[bool], - confusion_map: Optional[Dict[Tuple[int, ...], np.ndarray]] = None, - ): - raise TypeError(error_msg) - - # Verify that the check doesn't prevent other errors from being raised - qubits = cirq.LineQubit.range(2) - sv_state = ErrorProneSimState(qubits=qubits) - m = cirq.measure(*qubits, key='test') - with pytest.raises(TypeError, match=error_msg): - cirq.act_on(m, sv_state) diff --git a/cirq-core/cirq/sim/simulation_state.py b/cirq-core/cirq/sim/simulation_state.py index 911082d4927..7f14faa3118 100644 --- a/cirq-core/cirq/sim/simulation_state.py +++ b/cirq-core/cirq/sim/simulation_state.py @@ -82,7 +82,7 @@ def measure( qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[bool], - confusion_map: Optional[Dict[Tuple[int, ...], np.ndarray]] = None, + confusion_map: Dict[Tuple[int, ...], np.ndarray], ): """Measures the qubits and records to `log_of_measurement_results`. @@ -100,8 +100,7 @@ def measure( ValueError: If a measurement key has already been logged to a key. """ bits = self._perform_measurement(qubits) - if confusion_map is not None: - confused = self._confuse_result(bits, qubits, confusion_map) + confused = self._confuse_result(bits, qubits, confusion_map) corrected = [bit ^ (bit < 2 and mask) for bit, mask in zip(confused, invert_mask)] self._classical_data.record_measurement( value.MeasurementKey.parse_serialized(key), corrected, qubits @@ -184,10 +183,15 @@ def with_qubits(self: TSelf, qubits) -> TSelf: Args: qubits: The qubits to be added to the state space. - Regurns: + Returns: A new subclass object containing the extended state space. """ + # TODO(#5721): Fix inconsistent usage of the `state` argument in the + # SimulationState base (required) and in its derived classes (unknown + # in StateVectorSimulationState), then remove the pylint filter below. + # pylint: disable=missing-kwoa new_space = type(self)(qubits=qubits) # type: ignore + # pylint: enable=missing-kwoa return self.kronecker_product(new_space) def factor(