Skip to content

Commit

Permalink
Wrap private fields for DensePauliString (quantumlib#5064)
Browse files Browse the repository at this point in the history
quantumlib#4851 for `DensePauliString`. This was *slightly* less straightforward than others.

* Required moving `ALLOW_DEPRECATION_IN_TEST` or else a circular dependency was created.
* One place externally where setting the field still was necessary. (A factory method in `CliffordGate`).
* One member is of type `ndarray`, which is mutable, so we set `flags.writeable = False`
  • Loading branch information
daxfohl authored Mar 13, 2022
1 parent 6143210 commit e4ac8e6
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 12 deletions.
4 changes: 2 additions & 2 deletions cirq/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import sympy
import sympy.printing.repr

ALLOW_DEPRECATION_IN_TEST = 'ALLOW_DEPRECATION_IN_TEST'


try:
from functools import cached_property # pylint: disable=unused-import
Expand Down Expand Up @@ -144,8 +146,6 @@ def proper_eq(a: Any, b: Any) -> bool:


def _warn_or_error(msg):
from cirq.testing.deprecation import ALLOW_DEPRECATION_IN_TEST

deprecation_allowed = ALLOW_DEPRECATION_IN_TEST in os.environ
if _called_from_test() and not deprecation_allowed:
for filename, line_number, function_name, text in reversed(traceback.extract_stack()):
Expand Down
2 changes: 1 addition & 1 deletion cirq/ops/clifford_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def transform(self, pauli: Pauli) -> PauliTransform:
to = z_to
else:
to = x_to * z_to # Y = iXZ
to.coefficient *= 1j
to._coefficient *= 1j
# pauli_mask returns a value between 0 and 4 for [I, X, Y, Z].
to_gate = Pauli._XYZ[to.pauli_mask[0] - 1]
return PauliTransform(to=to_gate, flip=bool(to.coefficient != 1.0))
Expand Down
41 changes: 33 additions & 8 deletions cirq/ops/dense_pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import sympy

from cirq import protocols, linalg, value
from cirq._compat import proper_repr
from cirq._compat import deprecated, proper_repr
from cirq.ops import raw_types, identity, pauli_gates, global_phase_op, pauli_string
from cirq.type_workarounds import NotImplementedType

Expand Down Expand Up @@ -95,12 +95,37 @@ def __init__(
... coefficient=sympy.Symbol('t')))
t*IXYZ
"""
self.pauli_mask = _as_pauli_mask(pauli_mask)
self.coefficient = (
self._pauli_mask = _as_pauli_mask(pauli_mask)
self._coefficient = (
coefficient if isinstance(coefficient, sympy.Basic) else complex(coefficient)
)
if type(self) != MutableDensePauliString:
self.pauli_mask = np.copy(self.pauli_mask)
self._pauli_mask = np.copy(self.pauli_mask)
self._pauli_mask.flags.writeable = False

@property
def pauli_mask(self) -> np.ndarray:
return self._pauli_mask

@pauli_mask.setter # type: ignore
@deprecated(
deadline="v0.15",
fix="The mutators of this class are deprecated, instantiate a new object instead.",
)
def pauli_mask(self, pauli_mask: np.ndarray):
self._pauli_mask = pauli_mask

@property
def coefficient(self) -> complex:
return self._coefficient

@coefficient.setter # type: ignore
@deprecated(
deadline="v0.15",
fix="The mutators of this class are deprecated, instantiate a new object instead.",
)
def coefficient(self, coefficient: complex):
self._coefficient = coefficient

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['pauli_mask', 'coefficient'])
Expand Down Expand Up @@ -432,22 +457,22 @@ def __imul__(self, other):
f"other={repr(other)}"
)
self_mask = self.pauli_mask[: len(other.pauli_mask)]
self.coefficient *= _vectorized_pauli_mul_phase(self_mask, other.pauli_mask)
self.coefficient *= other.coefficient
self._coefficient *= _vectorized_pauli_mul_phase(self_mask, other.pauli_mask)
self._coefficient *= other.coefficient
self_mask ^= other.pauli_mask
return self

if isinstance(other, (sympy.Basic, numbers.Number)):
new_coef = protocols.mul(self.coefficient, other, default=None)
if new_coef is None:
return NotImplemented
self.coefficient = new_coef if isinstance(new_coef, sympy.Basic) else complex(new_coef)
self._coefficient = new_coef if isinstance(new_coef, sympy.Basic) else complex(new_coef)
return self

split = _attempt_value_to_pauli_index(other)
if split is not None:
p, i = split
self.coefficient *= _vectorized_pauli_mul_phase(self.pauli_mask[i], p)
self._coefficient *= _vectorized_pauli_mul_phase(self.pauli_mask[i], p)
self.pauli_mask[i] ^= p
return self

Expand Down
11 changes: 11 additions & 0 deletions cirq/ops/dense_pauli_string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,3 +658,14 @@ def test_symbolic():
assert p == cirq.MutableDensePauliString('XYZ', coefficient=t * r)
p /= r
assert p == cirq.MutableDensePauliString('XYZ', coefficient=t)


def test_setters_deprecated():
gate = cirq.DensePauliString('X')
mask = np.array([0, 3, 1, 2], dtype=np.uint8)
with cirq.testing.assert_deprecated('mutators', deadline='v0.15'):
gate.pauli_mask = mask
assert gate.pauli_mask is mask
with cirq.testing.assert_deprecated('mutators', deadline='v0.15'):
gate.coefficient = -1
assert gate.coefficient == -1
2 changes: 1 addition & 1 deletion cirq/testing/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
from typing import Iterator, Optional

ALLOW_DEPRECATION_IN_TEST = 'ALLOW_DEPRECATION_IN_TEST'
from cirq._compat import ALLOW_DEPRECATION_IN_TEST


@contextlib.contextmanager
Expand Down

0 comments on commit e4ac8e6

Please sign in to comment.