Skip to content

Commit

Permalink
Fail fast on measurements in has_unitary (#5020)
Browse files Browse the repository at this point in the history
* Fail fast on measurements in has_unitary

* Unit test now also uses circuit

* Implement has_unitary instead

* Remove unused imports

* Unit test that would blow up the memory
  • Loading branch information
tonybruguier authored Feb 25, 2022
1 parent 43b37fa commit 2d6c8bd
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 0 deletions.
3 changes: 3 additions & 0 deletions cirq-core/cirq/ops/measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def key(self, key: Union[str, 'cirq.MeasurementKey']):
def _qid_shape_(self) -> Tuple[int, ...]:
return self._qid_shape

def _has_unitary_(self) -> bool:
return False

def with_key(self, key: Union[str, 'cirq.MeasurementKey']) -> 'MeasurementGate':
"""Creates a measurement gate with a new key but otherwise identical."""
if key == self.key:
Expand Down
5 changes: 5 additions & 0 deletions cirq-core/cirq/ops/measurement_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ def test_measure_init(num_qubits):
cirq.MeasurementGate()


def test_measurement_has_unitary_returns_false():
gate = cirq.MeasurementGate(1, 'a')
assert not cirq.has_unitary(gate)


@pytest.mark.parametrize('num_qubits', [1, 2, 4])
def test_has_stabilizer_effect(num_qubits):
assert cirq.has_stabilizer_effect(cirq.MeasurementGate(num_qubits, 'a'))
Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/ops/pauli_measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def key(self, key: Union[str, 'cirq.MeasurementKey']) -> None:
def _qid_shape_(self) -> Tuple[int, ...]:
return (2,) * len(self._observable)

def _has_unitary_(self) -> bool:
return False

def with_key(self, key: Union[str, 'cirq.MeasurementKey']) -> 'PauliMeasurementGate':
"""Creates a pauli measurement gate with a new key but otherwise identical."""
if key == self.key:
Expand Down
5 changes: 5 additions & 0 deletions cirq-core/cirq/ops/pauli_measurement_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ def test_init(observable, key):
assert cirq.qid_shape(g) == (2,) * len(observable)


def test_measurement_has_unitary_returns_false():
gate = cirq.PauliMeasurementGate([cirq.X], 'a')
assert not cirq.has_unitary(gate)


def test_measurement_eq():
eq = cirq.testing.EqualsTester()
eq.make_equality_group(
Expand Down
24 changes: 24 additions & 0 deletions cirq-core/cirq/protocols/has_unitary_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import numpy as np
import pytest

import cirq

Expand All @@ -26,6 +27,29 @@ class No:
assert not cirq.has_unitary(No())


@pytest.mark.parametrize(
'measurement_gate',
(
cirq.MeasurementGate(1, 'a'),
cirq.PauliMeasurementGate([cirq.X], 'a'),
),
)
def test_fail_fast_measure(measurement_gate):
assert not cirq.has_unitary(measurement_gate)

qubit = cirq.NamedQubit('q0')
circuit = cirq.Circuit()
circuit += measurement_gate(qubit)
circuit += cirq.H(qubit)
assert not cirq.has_unitary(circuit)


def test_fail_fast_measure_large_memory():
num_qubits = 100
measurement_op = cirq.MeasurementGate(num_qubits, 'a').on(*cirq.LineQubit.range(num_qubits))
assert not cirq.has_unitary(measurement_op)


def test_via_unitary():
class No1:
def _unitary_(self):
Expand Down

0 comments on commit 2d6c8bd

Please sign in to comment.