Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace fields with properties in ActOnArgs #5011

Merged
merged 2 commits into from
Feb 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions cirq-core/cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,23 @@
import inspect
from typing import (
Any,
cast,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
TypeVar,
TYPE_CHECKING,
Sequence,
Tuple,
cast,
Optional,
Iterator,
)
import warnings

import numpy as np

from cirq import ops, protocols, value
from cirq._compat import deprecated
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
from cirq.sim.operation_target import OperationTarget

Expand Down Expand Up @@ -74,7 +76,7 @@ def __init__(
if qubits is None:
qubits = ()
self._set_qubits(qubits)
self.prng = prng
self._prng = prng
self._classical_data = classical_data or value.ClassicalDataDictionaryStore(
_records={
value.MeasurementKey.parse_serialized(k): [tuple(v)]
Expand All @@ -83,9 +85,33 @@ def __init__(
)
self._ignore_measurement_results = ignore_measurement_results

@property
def prng(self) -> np.random.RandomState:
return self._prng

@property
def qubit_map(self) -> Mapping['cirq.Qid', int]:
return self._qubit_map

@prng.setter # type: ignore
@deprecated(
deadline="v0.15",
fix="The mutators of this class are deprecated, instantiate a new object instead.",
)
def prng(self, prng):
self._prng = prng

@qubit_map.setter # type: ignore
@deprecated(
deadline="v0.15",
fix="The mutators of this class are deprecated, instantiate a new object instead.",
)
def qubit_map(self, qubit_map):
self._qubit_map = qubit_map

def _set_qubits(self, qubits: Sequence['cirq.Qid']):
self._qubits = tuple(qubits)
self.qubit_map = {q: i for i, q in enumerate(self.qubits)}
self._qubit_map = {q: i for i, q in enumerate(self.qubits)}

def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[bool]):
"""Measures the qubits and records to `log_of_measurement_results`.
Expand Down Expand Up @@ -281,8 +307,7 @@ def swap(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False):
i2 = self.qubits.index(q2)
qubits = list(args.qubits)
qubits[i1], qubits[i2] = qubits[i2], qubits[i1]
args._qubits = tuple(qubits)
args.qubit_map = {q: i for i, q in enumerate(qubits)}
args._set_qubits(qubits)
return args

def rename(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False):
Expand All @@ -309,8 +334,7 @@ def rename(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False):
i1 = self.qubits.index(q1)
qubits = list(args.qubits)
qubits[i1] = q2
args._qubits = tuple(qubits)
args.qubit_map = {q: i for i, q in enumerate(qubits)}
args._set_qubits(qubits)
return args

def __getitem__(self: TSelf, item: Optional['cirq.Qid']) -> TSelf:
Expand Down
40 changes: 33 additions & 7 deletions cirq-core/cirq/sim/act_on_args_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Generic,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Expand All @@ -30,6 +31,7 @@
import numpy as np

from cirq import ops, protocols, value
from cirq._compat import deprecated
from cirq.sim.operation_target import OperationTarget
from cirq.sim.simulator import (
TActOnArgs,
Expand Down Expand Up @@ -68,16 +70,40 @@ def __init__(
classical_data: The shared classical data container for this
simulation.
"""
self.args = args
self._args = args
self._qubits = tuple(qubits)
self.split_untangled_states = split_untangled_states
self._split_untangled_states = split_untangled_states
self._classical_data = classical_data or value.ClassicalDataDictionaryStore(
_records={
value.MeasurementKey.parse_serialized(k): [tuple(v)]
for k, v in (log_of_measurement_results or {}).items()
}
)

@property
def args(self) -> Mapping[Optional['cirq.Qid'], TActOnArgs]:
return self._args

@property
def split_untangled_states(self) -> bool:
return self._split_untangled_states

@args.setter # type: ignore
@deprecated(
deadline="v0.15",
fix="The mutators of this class are deprecated, instantiate a new object instead.",
)
def args(self, args):
self._args = args

@split_untangled_states.setter # type: ignore
@deprecated(
deadline="v0.15",
fix="The mutators of this class are deprecated, instantiate a new object instead.",
)
def split_untangled_states(self, split_untangled_states):
self._split_untangled_states = split_untangled_states

def create_merged_state(self) -> TActOnArgs:
if not self.split_untangled_states:
return self.args[None]
Expand All @@ -104,8 +130,8 @@ def _act_on_fallback_(
if args0 is args1:
args0.swap(q0, q1, inplace=True)
else:
self.args[q0] = args1.rename(q1, q0, inplace=True)
self.args[q1] = args0.rename(q0, q1, inplace=True)
self._args[q0] = args1.rename(q1, q0, inplace=True)
self._args[q1] = args0.rename(q0, q1, inplace=True)
return True

# Go through the op's qubits and join any disparate ActOnArgs states
Expand All @@ -120,7 +146,7 @@ def _act_on_fallback_(

# (Backfill the args map with the new value)
for q in op_args.qubits:
self.args[q] = op_args
self._args[q] = op_args

# Act on the args with the operation
act_on_qubits = qubits if isinstance(action, ops.Gate) else None
Expand All @@ -134,11 +160,11 @@ def _act_on_fallback_(
for q in qubits:
if op_args.allows_factoring:
q_args, op_args = op_args.factor((q,), validate=False)
self.args[q] = q_args
self._args[q] = q_args

# (Backfill the args map with the new value)
for q in op_args.qubits:
self.args[q] = op_args
self._args[q] = op_args
return True

def copy(self, deep_copy_buffers: bool = True) -> 'cirq.ActOnArgsContainer[TActOnArgs]':
Expand Down
14 changes: 14 additions & 0 deletions cirq-core/cirq/sim/act_on_args_container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,17 @@ def test_act_on_gate_does_not_join():
assert len(set(args.values())) == 3
assert args[q0] is not args[q1]
assert args[q0] is not args[None]


def test_field_getters():
args = create_container(qs2)
assert args.args.keys() == set(qs2) | {None}
assert args.split_untangled_states


def test_field_setters_deprecated():
args = create_container(qs2)
with cirq.testing.assert_deprecated(deadline='v0.15'):
args.args = {}
with cirq.testing.assert_deprecated(deadline='v0.15'):
args.split_untangled_states = False
15 changes: 15 additions & 0 deletions cirq-core/cirq/sim/act_on_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from typing import Sequence, Union

import numpy as np
import pytest

import cirq
Expand Down Expand Up @@ -98,3 +99,17 @@ def test_on_copy_has_no_param():
args = DummyArgs()
with cirq.testing.assert_deprecated('deep_copy_buffers', deadline='0.15'):
args.copy(False)


def test_field_getters():
args = DummyArgs()
assert args.prng is np.random
assert args.qubit_map == {q: i for i, q in enumerate(cirq.LineQubit.range(2))}


def test_field_setters_deprecated():
args = DummyArgs()
with cirq.testing.assert_deprecated(deadline='v0.15'):
args.prng = 0
with cirq.testing.assert_deprecated(deadline='v0.15'):
args.qubit_map = {}