From 2d6c8bd858eb9dcfe214a2a6b48331b9ac253ee3 Mon Sep 17 00:00:00 2001 From: "Antoine (Tony) Bruguier" Date: Fri, 25 Feb 2022 08:01:02 -0800 Subject: [PATCH] Fail fast on measurements in has_unitary (#5020) * Fail fast on measurements in has_unitary * Unit test now also uses circuit * Implement has_unitary instead * Remove unused imports * Unit test that would blow up the memory --- cirq-core/cirq/ops/measurement_gate.py | 3 +++ cirq-core/cirq/ops/measurement_gate_test.py | 5 ++++ cirq-core/cirq/ops/pauli_measurement_gate.py | 3 +++ .../cirq/ops/pauli_measurement_gate_test.py | 5 ++++ .../protocols/has_unitary_protocol_test.py | 24 +++++++++++++++++++ 5 files changed, 40 insertions(+) diff --git a/cirq-core/cirq/ops/measurement_gate.py b/cirq-core/cirq/ops/measurement_gate.py index e7f65f69966..edd209905ff 100644 --- a/cirq-core/cirq/ops/measurement_gate.py +++ b/cirq-core/cirq/ops/measurement_gate.py @@ -84,6 +84,9 @@ def key(self, key: Union[str, 'cirq.MeasurementKey']): def _qid_shape_(self) -> Tuple[int, ...]: return self._qid_shape + def _has_unitary_(self) -> bool: + return False + def with_key(self, key: Union[str, 'cirq.MeasurementKey']) -> 'MeasurementGate': """Creates a measurement gate with a new key but otherwise identical.""" if key == self.key: diff --git a/cirq-core/cirq/ops/measurement_gate_test.py b/cirq-core/cirq/ops/measurement_gate_test.py index 2272b9a58b0..bdcfa2a8bac 100644 --- a/cirq-core/cirq/ops/measurement_gate_test.py +++ b/cirq-core/cirq/ops/measurement_gate_test.py @@ -59,6 +59,11 @@ def test_measure_init(num_qubits): cirq.MeasurementGate() +def test_measurement_has_unitary_returns_false(): + gate = cirq.MeasurementGate(1, 'a') + assert not cirq.has_unitary(gate) + + @pytest.mark.parametrize('num_qubits', [1, 2, 4]) def test_has_stabilizer_effect(num_qubits): assert cirq.has_stabilizer_effect(cirq.MeasurementGate(num_qubits, 'a')) diff --git a/cirq-core/cirq/ops/pauli_measurement_gate.py b/cirq-core/cirq/ops/pauli_measurement_gate.py index 2b2667a00a3..b4e29e8b726 100644 --- a/cirq-core/cirq/ops/pauli_measurement_gate.py +++ b/cirq-core/cirq/ops/pauli_measurement_gate.py @@ -83,6 +83,9 @@ def key(self, key: Union[str, 'cirq.MeasurementKey']) -> None: def _qid_shape_(self) -> Tuple[int, ...]: return (2,) * len(self._observable) + def _has_unitary_(self) -> bool: + return False + def with_key(self, key: Union[str, 'cirq.MeasurementKey']) -> 'PauliMeasurementGate': """Creates a pauli measurement gate with a new key but otherwise identical.""" if key == self.key: diff --git a/cirq-core/cirq/ops/pauli_measurement_gate_test.py b/cirq-core/cirq/ops/pauli_measurement_gate_test.py index 030192bc00c..348f414f9cc 100644 --- a/cirq-core/cirq/ops/pauli_measurement_gate_test.py +++ b/cirq-core/cirq/ops/pauli_measurement_gate_test.py @@ -47,6 +47,11 @@ def test_init(observable, key): assert cirq.qid_shape(g) == (2,) * len(observable) +def test_measurement_has_unitary_returns_false(): + gate = cirq.PauliMeasurementGate([cirq.X], 'a') + assert not cirq.has_unitary(gate) + + def test_measurement_eq(): eq = cirq.testing.EqualsTester() eq.make_equality_group( diff --git a/cirq-core/cirq/protocols/has_unitary_protocol_test.py b/cirq-core/cirq/protocols/has_unitary_protocol_test.py index 64e1196fd90..f0222013801 100644 --- a/cirq-core/cirq/protocols/has_unitary_protocol_test.py +++ b/cirq-core/cirq/protocols/has_unitary_protocol_test.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np +import pytest import cirq @@ -26,6 +27,29 @@ class No: assert not cirq.has_unitary(No()) +@pytest.mark.parametrize( + 'measurement_gate', + ( + cirq.MeasurementGate(1, 'a'), + cirq.PauliMeasurementGate([cirq.X], 'a'), + ), +) +def test_fail_fast_measure(measurement_gate): + assert not cirq.has_unitary(measurement_gate) + + qubit = cirq.NamedQubit('q0') + circuit = cirq.Circuit() + circuit += measurement_gate(qubit) + circuit += cirq.H(qubit) + assert not cirq.has_unitary(circuit) + + +def test_fail_fast_measure_large_memory(): + num_qubits = 100 + measurement_op = cirq.MeasurementGate(num_qubits, 'a').on(*cirq.LineQubit.range(num_qubits)) + assert not cirq.has_unitary(measurement_op) + + def test_via_unitary(): class No1: def _unitary_(self):