From e9b62583ec58babf8d8fac0b0fea7ad0d3e515d1 Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Tue, 21 Dec 2021 13:14:36 -0800 Subject: [PATCH] Sanitize type annotations in cirq.sim (#4773) * Sanitize type annotations in cirq.sim * keys --- cirq-core/cirq/circuits/circuit.py | 8 +- cirq-core/cirq/circuits/circuit_operation.py | 6 +- cirq-core/cirq/circuits/frozen_circuit.py | 12 +-- .../cirq/contrib/qcircuit/qcircuit_diagram.py | 2 +- cirq-core/cirq/contrib/quimb/mps_simulator.py | 12 +-- .../ops/classically_controlled_operation.py | 4 +- cirq-core/cirq/ops/gate_operation.py | 4 +- cirq-core/cirq/ops/kraus_channel.py | 13 ++- cirq-core/cirq/ops/measure_util.py | 6 +- cirq-core/cirq/ops/measurement_gate.py | 8 +- cirq-core/cirq/ops/mixed_unitary_channel.py | 13 ++- cirq-core/cirq/ops/moment.py | 9 +- cirq-core/cirq/ops/pauli_measurement_gate.py | 8 +- cirq-core/cirq/ops/raw_types.py | 4 +- .../protocols/measurement_key_protocol.py | 17 ++-- cirq-core/cirq/sim/act_on_args.py | 2 +- cirq-core/cirq/sim/act_on_args_container.py | 4 +- .../cirq/sim/act_on_density_matrix_args.py | 12 +-- .../cirq/sim/act_on_state_vector_args.py | 12 ++- .../clifford/act_on_clifford_tableau_args.py | 3 +- .../act_on_stabilizer_ch_form_args.py | 3 +- .../cirq/sim/clifford/clifford_simulator.py | 26 +++--- .../cirq/sim/clifford/stabilizer_sampler.py | 4 +- .../cirq/sim/density_matrix_simulator.py | 32 +++---- cirq-core/cirq/sim/mux.py | 16 ++-- cirq-core/cirq/sim/simulator.py | 92 +++++++++---------- cirq-core/cirq/sim/simulator_base.py | 14 +-- cirq-core/cirq/sim/sparse_simulator.py | 6 +- cirq-core/cirq/sim/state_vector.py | 11 +-- cirq-core/cirq/sim/state_vector_simulator.py | 23 +++-- 30 files changed, 195 insertions(+), 191 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index 06a634e2291..abb23189fdb 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -50,7 +50,7 @@ import numpy as np import cirq._version -from cirq import devices, ops, protocols, value, qis +from cirq import devices, ops, protocols, qis from cirq.circuits._bucket_priority_queue import BucketPriorityQueue from cirq.circuits.circuit_operation import CircuitOperation from cirq.circuits.insert_strategy import InsertStrategy @@ -886,10 +886,10 @@ def qid_shape( qids = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits()) return protocols.qid_shape(qids) - def all_measurement_key_objs(self) -> AbstractSet[value.MeasurementKey]: + def all_measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']: return {key for op in self.all_operations() for key in protocols.measurement_key_objs(op)} - def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]: + def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: return self.all_measurement_key_objs() def all_measurement_key_names(self) -> AbstractSet[str]: @@ -1537,7 +1537,7 @@ def factorize(self: CIRCUIT_TYPE) -> Iterable[CIRCUIT_TYPE]: self._with_sliced_moments([m[qubits] for m in self.moments]) for qubits in qubit_factors ) - def _control_keys_(self) -> FrozenSet[value.MeasurementKey]: + def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: controls = frozenset(k for op in self.all_operations() for k in protocols.control_keys(op)) return controls - protocols.measurement_key_objs(self) diff --git a/cirq-core/cirq/circuits/circuit_operation.py b/cirq-core/cirq/circuits/circuit_operation.py index 15c9ef4af10..d10857d1a5f 100644 --- a/cirq-core/cirq/circuits/circuit_operation.py +++ b/cirq-core/cirq/circuits/circuit_operation.py @@ -93,7 +93,7 @@ class CircuitOperation(ops.Operation): """ _hash: Optional[int] = dataclasses.field(default=None, init=False) - _cached_measurement_key_objs: Optional[AbstractSet[value.MeasurementKey]] = dataclasses.field( + _cached_measurement_key_objs: Optional[AbstractSet['cirq.MeasurementKey']] = dataclasses.field( default=None, init=False ) @@ -184,7 +184,7 @@ def _qid_shape_(self) -> Tuple[int, ...]: def _is_measurement_(self) -> bool: return self.circuit._is_measurement_() - def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]: + def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: if self._cached_measurement_key_objs is None: circuit_keys = protocols.measurement_key_objs(self.circuit) if self.repetition_ids is not None: @@ -207,7 +207,7 @@ def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]: def _measurement_key_names_(self) -> AbstractSet[str]: return {str(key) for key in self._measurement_key_objs_()} - def _control_keys_(self) -> AbstractSet[value.MeasurementKey]: + def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']: if not protocols.control_keys(self.circuit): return frozenset() return protocols.control_keys(self.mapped_circuit()) diff --git a/cirq-core/cirq/circuits/frozen_circuit.py b/cirq-core/cirq/circuits/frozen_circuit.py index 50f2fe9b4e2..be21ae0f95f 100644 --- a/cirq-core/cirq/circuits/frozen_circuit.py +++ b/cirq-core/cirq/circuits/frozen_circuit.py @@ -27,7 +27,7 @@ import numpy as np -from cirq import devices, ops, protocols, value +from cirq import devices, ops, protocols from cirq.circuits import AbstractCircuit, Alignment, Circuit from cirq.circuits.insert_strategy import InsertStrategy from cirq.type_workarounds import NotImplementedType @@ -74,9 +74,9 @@ def __init__( self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None self._all_operations: Optional[Tuple[ops.Operation, ...]] = None self._has_measurements: Optional[bool] = None - self._all_measurement_key_objs: Optional[AbstractSet[value.MeasurementKey]] = None + self._all_measurement_key_objs: Optional[AbstractSet['cirq.MeasurementKey']] = None self._are_all_measurements_terminal: Optional[bool] = None - self._control_keys: Optional[FrozenSet[value.MeasurementKey]] = None + self._control_keys: Optional[FrozenSet['cirq.MeasurementKey']] = None @property def moments(self) -> Sequence['cirq.Moment']: @@ -126,15 +126,15 @@ def has_measurements(self) -> bool: self._has_measurements = super().has_measurements() return self._has_measurements - def all_measurement_key_objs(self) -> AbstractSet[value.MeasurementKey]: + def all_measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']: if self._all_measurement_key_objs is None: self._all_measurement_key_objs = super().all_measurement_key_objs() return self._all_measurement_key_objs - def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]: + def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: return self.all_measurement_key_objs() - def _control_keys_(self) -> FrozenSet[value.MeasurementKey]: + def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: if self._control_keys is None: self._control_keys = super()._control_keys_() return self._control_keys diff --git a/cirq-core/cirq/contrib/qcircuit/qcircuit_diagram.py b/cirq-core/cirq/contrib/qcircuit/qcircuit_diagram.py index 05767ac97c1..00d9219a7e5 100644 --- a/cirq-core/cirq/contrib/qcircuit/qcircuit_diagram.py +++ b/cirq-core/cirq/contrib/qcircuit/qcircuit_diagram.py @@ -64,7 +64,7 @@ def _render(diagram: circuits.TextDiagramDrawer) -> str: def circuit_to_latex_using_qcircuit( - circuit: circuits.Circuit, qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT + circuit: 'cirq.Circuit', qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT ) -> str: """Returns a QCircuit-based latex diagram of the given circuit. diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator.py b/cirq-core/cirq/contrib/quimb/mps_simulator.py index 1073672f4dc..14f3b35ba1c 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator.py @@ -24,7 +24,7 @@ import numpy as np import quimb.tensor as qtn -from cirq import devices, study, ops, protocols, value +from cirq import devices, ops, protocols, value from cirq.sim import simulator_base from cirq.sim.act_on_args import ActOnArgs @@ -126,7 +126,7 @@ def _create_step_result( def _create_simulator_trial_result( self, - params: study.ParamResolver, + params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], final_step_result: 'MPSSimulatorStepResult', ) -> 'MPSTrialResult': @@ -151,7 +151,7 @@ class MPSTrialResult(simulator_base.SimulationTrialResultBase['MPSState', 'MPSSt def __init__( self, - params: study.ParamResolver, + params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], final_step_result: 'MPSSimulatorStepResult', ) -> None: @@ -321,7 +321,7 @@ def state_vector(self) -> np.ndarray: sorted_ind = tuple(sorted(state_vector.inds)) return state_vector.fuse({'i': sorted_ind}).data - def partial_trace(self, keep_qubits: Set[ops.Qid]) -> np.ndarray: + def partial_trace(self, keep_qubits: Set['cirq.Qid']) -> np.ndarray: """Traces out all qubits except keep_qubits. Args: @@ -475,7 +475,7 @@ def estimation_stats(self): } def perform_measurement( - self, qubits: Sequence[ops.Qid], prng: np.random.RandomState, collapse_state_vector=True + self, qubits: Sequence['cirq.Qid'], prng: np.random.RandomState, collapse_state_vector=True ) -> List[int]: """Performs a measurement over one or more qubits. @@ -533,7 +533,7 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: def sample( self, - qubits: Sequence[ops.Qid], + qubits: Sequence['cirq.Qid'], repetitions: int = 1, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> np.ndarray: diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index d4da88fdcc2..fe3093a3d4e 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -184,7 +184,7 @@ def _with_rescoped_keys_( path: Tuple[str, ...], bindable_keys: FrozenSet['cirq.MeasurementKey'], ) -> 'ClassicallyControlledOperation': - def map_key(key: value.MeasurementKey) -> value.MeasurementKey: + def map_key(key: 'cirq.MeasurementKey') -> 'cirq.MeasurementKey': for i in range(len(path) + 1): back_path = path[: len(path) - i] new_key = key.with_key_path_prefix(*back_path) @@ -195,7 +195,7 @@ def map_key(key: value.MeasurementKey) -> value.MeasurementKey: sub_operation = protocols.with_rescoped_keys(self._sub_operation, path, bindable_keys) return sub_operation.with_classical_controls(*[map_key(k) for k in self._control_keys]) - def _control_keys_(self) -> FrozenSet[value.MeasurementKey]: + def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: return frozenset(self._control_keys).union(protocols.control_keys(self._sub_operation)) def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]: diff --git a/cirq-core/cirq/ops/gate_operation.py b/cirq-core/cirq/ops/gate_operation.py index 706d91d2f0c..ef258381eda 100644 --- a/cirq-core/cirq/ops/gate_operation.py +++ b/cirq-core/cirq/ops/gate_operation.py @@ -248,13 +248,13 @@ def _measurement_key_names_(self) -> Optional[AbstractSet[str]]: return getter() return NotImplemented - def _measurement_key_obj_(self) -> Optional[value.MeasurementKey]: + def _measurement_key_obj_(self) -> Optional['cirq.MeasurementKey']: getter = getattr(self.gate, '_measurement_key_obj_', None) if getter is not None: return getter() return NotImplemented - def _measurement_key_objs_(self) -> Optional[AbstractSet[value.MeasurementKey]]: + def _measurement_key_objs_(self) -> Optional[AbstractSet['cirq.MeasurementKey']]: getter = getattr(self.gate, '_measurement_key_objs_', None) if getter is not None: return getter() diff --git a/cirq-core/cirq/ops/kraus_channel.py b/cirq-core/cirq/ops/kraus_channel.py index 402606e61c5..6dadbd5872b 100644 --- a/cirq-core/cirq/ops/kraus_channel.py +++ b/cirq-core/cirq/ops/kraus_channel.py @@ -1,11 +1,14 @@ # pylint: disable=wrong-or-nonexistent-copyright-notice -from typing import Any, Dict, FrozenSet, Iterable, Tuple, Union +from typing import Any, Dict, FrozenSet, Iterable, Tuple, TYPE_CHECKING, Union import numpy as np from cirq import linalg, protocols, value from cirq._compat import proper_repr from cirq.ops import raw_types +if TYPE_CHECKING: + import cirq + # TODO(#3241): support qudits and non-square operators. class KrausChannel(raw_types.Gate): @@ -25,7 +28,7 @@ class KrausChannel(raw_types.Gate): def __init__( self, kraus_ops: Iterable[np.ndarray], - key: Union[str, value.MeasurementKey, None] = None, + key: Union[str, 'cirq.MeasurementKey', None] = None, validate: bool = False, ): kraus_ops = list(kraus_ops) @@ -52,7 +55,7 @@ def __init__( self._key = key @staticmethod - def from_channel(channel: 'KrausChannel', key: Union[str, value.MeasurementKey, None] = None): + def from_channel(channel: 'KrausChannel', key: Union[str, 'cirq.MeasurementKey', None] = None): """Creates a copy of a channel with the given measurement key.""" return KrausChannel(kraus_ops=list(protocols.kraus(channel)), key=key) @@ -76,7 +79,7 @@ def _measurement_key_name_(self) -> str: return NotImplemented return str(self._key) - def _measurement_key_obj_(self) -> value.MeasurementKey: + def _measurement_key_obj_(self) -> 'cirq.MeasurementKey': if self._key is None: return NotImplemented return self._key @@ -99,7 +102,7 @@ def _with_key_path_prefix_(self, prefix: Tuple[str, ...]): def _with_rescoped_keys_( self, path: Tuple[str, ...], - bindable_keys: FrozenSet[value.MeasurementKey], + bindable_keys: FrozenSet['cirq.MeasurementKey'], ): return KrausChannel( kraus_ops=self._kraus_ops, diff --git a/cirq-core/cirq/ops/measure_util.py b/cirq-core/cirq/ops/measure_util.py index 5ab330290e2..1678949771a 100644 --- a/cirq-core/cirq/ops/measure_util.py +++ b/cirq-core/cirq/ops/measure_util.py @@ -16,7 +16,7 @@ import numpy as np -from cirq import protocols, value +from cirq import protocols from cirq.ops import raw_types, pauli_string from cirq.ops.measurement_gate import MeasurementGate from cirq.ops.pauli_measurement_gate import PauliMeasurementGate @@ -31,7 +31,7 @@ def _default_measurement_key(qubits: Iterable[raw_types.Qid]) -> str: def measure_single_paulistring( pauli_observable: pauli_string.PauliString, - key: Optional[Union[str, value.MeasurementKey]] = None, + key: Optional[Union[str, 'cirq.MeasurementKey']] = None, ) -> raw_types.Operation: """Returns a single PauliMeasurementGate which measures the pauli observable @@ -83,7 +83,7 @@ def measure_paulistring_terms( def measure( *target: 'cirq.Qid', - key: Optional[Union[str, value.MeasurementKey]] = None, + key: Optional[Union[str, 'cirq.MeasurementKey']] = None, invert_mask: Tuple[bool, ...] = (), ) -> raw_types.Operation: """Returns a single MeasurementGate applied to all the given qubits. diff --git a/cirq-core/cirq/ops/measurement_gate.py b/cirq-core/cirq/ops/measurement_gate.py index 58eecea0d3f..e7f65f69966 100644 --- a/cirq-core/cirq/ops/measurement_gate.py +++ b/cirq-core/cirq/ops/measurement_gate.py @@ -34,7 +34,7 @@ class MeasurementGate(raw_types.Gate): def __init__( self, num_qubits: Optional[int] = None, - key: Union[str, value.MeasurementKey] = '', + key: Union[str, 'cirq.MeasurementKey'] = '', invert_mask: Tuple[bool, ...] = (), qid_shape: Tuple[int, ...] = None, ) -> None: @@ -75,7 +75,7 @@ def key(self) -> str: return str(self.mkey) @key.setter - def key(self, key: Union[str, value.MeasurementKey]): + def key(self, key: Union[str, 'cirq.MeasurementKey']): if isinstance(key, value.MeasurementKey): self.mkey = key else: @@ -84,7 +84,7 @@ def key(self, key: Union[str, value.MeasurementKey]): def _qid_shape_(self) -> Tuple[int, ...]: return self._qid_shape - def with_key(self, key: Union[str, value.MeasurementKey]) -> 'MeasurementGate': + def with_key(self, key: Union[str, 'cirq.MeasurementKey']) -> 'MeasurementGate': """Creates a measurement gate with a new key but otherwise identical.""" if key == self.key: return self @@ -139,7 +139,7 @@ def _is_measurement_(self) -> bool: def _measurement_key_name_(self) -> str: return self.key - def _measurement_key_obj_(self) -> value.MeasurementKey: + def _measurement_key_obj_(self) -> 'cirq.MeasurementKey': return self.mkey def _kraus_(self): diff --git a/cirq-core/cirq/ops/mixed_unitary_channel.py b/cirq-core/cirq/ops/mixed_unitary_channel.py index eb03e78f5d6..5f5bdb9b9e2 100644 --- a/cirq-core/cirq/ops/mixed_unitary_channel.py +++ b/cirq-core/cirq/ops/mixed_unitary_channel.py @@ -1,11 +1,14 @@ # pylint: disable=wrong-or-nonexistent-copyright-notice -from typing import Any, Dict, FrozenSet, Iterable, Tuple, Union +from typing import Any, Dict, FrozenSet, Iterable, Tuple, TYPE_CHECKING, Union import numpy as np from cirq import linalg, protocols, value from cirq._compat import proper_repr from cirq.ops import raw_types +if TYPE_CHECKING: + import cirq + class MixedUnitaryChannel(raw_types.Gate): """A generic mixture that can record the index of its selected operator. @@ -24,7 +27,7 @@ class MixedUnitaryChannel(raw_types.Gate): def __init__( self, mixture: Iterable[Tuple[float, np.ndarray]], - key: Union[str, value.MeasurementKey, None] = None, + key: Union[str, 'cirq.MeasurementKey', None] = None, validate: bool = False, ): mixture = list(mixture) @@ -54,7 +57,7 @@ def __init__( @staticmethod def from_mixture( - mixture: 'protocols.SupportsMixture', key: Union[str, value.MeasurementKey, None] = None + mixture: 'protocols.SupportsMixture', key: Union[str, 'cirq.MeasurementKey', None] = None ): """Creates a copy of a mixture with the given measurement key.""" return MixedUnitaryChannel(mixture=list(protocols.mixture(mixture)), key=key) @@ -85,7 +88,7 @@ def _measurement_key_name_(self) -> str: return NotImplemented return str(self._key) - def _measurement_key_obj_(self) -> value.MeasurementKey: + def _measurement_key_obj_(self) -> 'cirq.MeasurementKey': if self._key is None: return NotImplemented return self._key @@ -110,7 +113,7 @@ def _with_key_path_prefix_(self, prefix: Tuple[str, ...]): def _with_rescoped_keys_( self, path: Tuple[str, ...], - bindable_keys: FrozenSet[value.MeasurementKey], + bindable_keys: FrozenSet['cirq.MeasurementKey'], ): return MixedUnitaryChannel( mixture=self._mixture, diff --git a/cirq-core/cirq/ops/moment.py b/cirq-core/cirq/ops/moment.py index 92e4de47bef..76020927bd5 100644 --- a/cirq-core/cirq/ops/moment.py +++ b/cirq-core/cirq/ops/moment.py @@ -14,6 +14,7 @@ """A simplified time-slice of operations within a sequenced circuit.""" +import itertools from typing import ( AbstractSet, Any, @@ -31,11 +32,9 @@ Union, ) -import itertools - import numpy as np -from cirq import protocols, ops, qis, value +from cirq import protocols, ops, qis from cirq._import import LazyLoader from cirq.ops import raw_types from cirq.protocols import circuit_diagram_info_protocol @@ -103,7 +102,7 @@ def __init__(self, *contents: 'cirq.OP_TREE') -> None: self._qubit_to_op[q] = op self._qubits = frozenset(self._qubit_to_op.keys()) - self._measurement_key_objs: Optional[AbstractSet[value.MeasurementKey]] = None + self._measurement_key_objs: Optional[AbstractSet['cirq.MeasurementKey']] = None @property def operations(self) -> Tuple['cirq.Operation', ...]: @@ -233,7 +232,7 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]): def _measurement_key_names_(self) -> AbstractSet[str]: return {str(key) for key in self._measurement_key_objs_()} - def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]: + def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: if self._measurement_key_objs is None: self._measurement_key_objs = { key for op in self.operations for key in protocols.measurement_key_objs(op) diff --git a/cirq-core/cirq/ops/pauli_measurement_gate.py b/cirq-core/cirq/ops/pauli_measurement_gate.py index c383ff20030..8f9a7358f75 100644 --- a/cirq-core/cirq/ops/pauli_measurement_gate.py +++ b/cirq-core/cirq/ops/pauli_measurement_gate.py @@ -39,7 +39,7 @@ class PauliMeasurementGate(raw_types.Gate): def __init__( self, observable: Iterable['cirq.Pauli'], - key: Union[str, value.MeasurementKey] = '', + key: Union[str, 'cirq.MeasurementKey'] = '', ) -> None: """Inits PauliMeasurementGate. @@ -64,7 +64,7 @@ def key(self) -> str: return str(self.mkey) @key.setter - def key(self, key: Union[str, value.MeasurementKey]) -> None: + def key(self, key: Union[str, 'cirq.MeasurementKey']) -> None: if isinstance(key, str): key = value.MeasurementKey(name=key) self.mkey = key @@ -72,7 +72,7 @@ def key(self, key: Union[str, value.MeasurementKey]) -> None: def _qid_shape_(self) -> Tuple[int, ...]: return (2,) * len(self._observable) - def with_key(self, key: Union[str, value.MeasurementKey]) -> 'PauliMeasurementGate': + def with_key(self, key: Union[str, 'cirq.MeasurementKey']) -> 'PauliMeasurementGate': """Creates a pauli measurement gate with a new key but otherwise identical.""" if key == self.key: return self @@ -106,7 +106,7 @@ def _is_measurement_(self) -> bool: def _measurement_key_name_(self) -> str: return self.key - def _measurement_key_obj_(self) -> value.MeasurementKey: + def _measurement_key_obj_(self) -> 'cirq.MeasurementKey': return self.mkey def observable(self) -> 'cirq.DensePauliString': diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index 8212ba23f3e..c88ba65eeb8 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -754,7 +754,7 @@ def _kraus_(self) -> Union[Tuple[np.ndarray], NotImplementedType]: def _measurement_key_names_(self) -> AbstractSet[str]: return protocols.measurement_key_names(self.sub_operation) - def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]: + def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: return protocols.measurement_key_objs(self.sub_operation) def _is_measurement_(self) -> bool: @@ -825,7 +825,7 @@ def without_classical_controls(self) -> 'cirq.Operation': new_sub_operation = self.sub_operation.without_classical_controls() return self if new_sub_operation is self.sub_operation else new_sub_operation - def _control_keys_(self) -> AbstractSet[value.MeasurementKey]: + def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']: return protocols.control_keys(self.sub_operation) diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index fc398d2cf61..639df1aa180 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -13,12 +13,15 @@ # limitations under the License. """Protocol for object that have measurement keys.""" -from typing import AbstractSet, Any, Dict, FrozenSet, Optional, Tuple +from typing import AbstractSet, Any, Dict, FrozenSet, Optional, Tuple, TYPE_CHECKING from typing_extensions import Protocol -from cirq._doc import doc_private from cirq import value +from cirq._doc import doc_private + +if TYPE_CHECKING: + import cirq # This is a special indicator value used by the inverse method to determine # whether or not the caller provided a 'default' argument. @@ -56,7 +59,7 @@ def _is_measurement_(self) -> bool: """Return if this object is (or contains) a measurement.""" @doc_private - def _measurement_key_obj_(self) -> value.MeasurementKey: + def _measurement_key_obj_(self) -> 'cirq.MeasurementKey': """Return the key object that will be used to identify this measurement. When a measurement occurs, either on hardware, or in a simulation, @@ -65,7 +68,7 @@ def _measurement_key_obj_(self) -> value.MeasurementKey: """ @doc_private - def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]: + def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']: """Return the key objects for measurements performed by the receiving object. When a measurement occurs, either on hardware, or in a simulation, @@ -169,7 +172,7 @@ def measurement_key_name(val: Any, default: Any = RaiseTypeErrorIfNotProvided): def _measurement_key_objs_from_magic_methods( val: Any, -) -> Optional[AbstractSet[value.MeasurementKey]]: +) -> Optional[AbstractSet['cirq.MeasurementKey']]: """Uses the measurement key related magic methods to get the `MeasurementKey`s for this object.""" @@ -201,7 +204,7 @@ def _measurement_key_names_from_magic_methods(val: Any) -> Optional[AbstractSet[ return result -def measurement_key_objs(val: Any) -> AbstractSet[value.MeasurementKey]: +def measurement_key_objs(val: Any) -> AbstractSet['cirq.MeasurementKey']: """Gets the measurement key objects of measurements within the given value. Args: @@ -310,7 +313,7 @@ def with_key_path_prefix(val: Any, prefix: Tuple[str, ...]): def with_rescoped_keys( val: Any, path: Tuple[str, ...], - bindable_keys: FrozenSet[value.MeasurementKey] = None, + bindable_keys: FrozenSet['cirq.MeasurementKey'] = None, ): """Rescopes any measurement and control keys to the provided path, given the existing keys. diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index d28cda196f8..5a21502cb8f 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -262,7 +262,7 @@ def __iter__(self) -> Iterator[Optional['cirq.Qid']]: def strat_act_on_from_apply_decompose( val: Any, - args: ActOnArgs, + args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid'], ) -> bool: operations, qubits1, _ = _try_decompose_into_operations_and_qubits(val) diff --git a/cirq-core/cirq/sim/act_on_args_container.py b/cirq-core/cirq/sim/act_on_args_container.py index 4433c0ac2bd..8e3d04c952d 100644 --- a/cirq-core/cirq/sim/act_on_args_container.py +++ b/cirq-core/cirq/sim/act_on_args_container.py @@ -130,7 +130,7 @@ def _act_on_fallback_( self.args[q] = op_args return True - def copy(self) -> 'ActOnArgsContainer[TActOnArgs]': + def copy(self) -> 'cirq.ActOnArgsContainer[TActOnArgs]': logs = self.log_of_measurement_results.copy() copies = {a: a.copy() for a in set(self.args.values())} for copy in copies.values(): @@ -148,7 +148,7 @@ def log_of_measurement_results(self) -> Dict[str, Any]: def sample( self, - qubits: List[ops.Qid], + qubits: List['cirq.Qid'], repetitions: int = 1, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> np.ndarray: diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index 2e9734954cd..e32e558cebf 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -103,12 +103,12 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: ) return bits - def _on_copy(self, target: 'ActOnDensityMatrixArgs'): + def _on_copy(self, target: 'cirq.ActOnDensityMatrixArgs'): target.target_tensor = self.target_tensor.copy() target.available_buffer = [b.copy() for b in self.available_buffer] def _on_kronecker_product( - self, other: 'ActOnDensityMatrixArgs', target: 'ActOnDensityMatrixArgs' + self, other: 'cirq.ActOnDensityMatrixArgs', target: 'cirq.ActOnDensityMatrixArgs' ): target_tensor = transformations.density_matrix_kronecker_product( self.target_tensor, other.target_tensor @@ -122,8 +122,8 @@ def _on_kronecker_product( def _on_factor( self, qubits: Sequence['cirq.Qid'], - extracted: 'ActOnDensityMatrixArgs', - remainder: 'ActOnDensityMatrixArgs', + extracted: 'cirq.ActOnDensityMatrixArgs', + remainder: 'cirq.ActOnDensityMatrixArgs', validate=True, atol=1e-07, ): @@ -143,7 +143,7 @@ def _on_factor( remainder.qid_shape = remainder_tensor.shape[: int(remainder_tensor.ndim / 2)] def _on_transpose_to_qubit_order( - self, qubits: Sequence['cirq.Qid'], target: 'ActOnDensityMatrixArgs' + self, qubits: Sequence['cirq.Qid'], target: 'cirq.ActOnDensityMatrixArgs' ): axes = self.get_axes(qubits) new_tensor = transformations.transpose_density_matrix_to_axis_order( @@ -181,7 +181,7 @@ def __repr__(self) -> str: def _strat_apply_channel_to_state( - action: Any, args: ActOnDensityMatrixArgs, qubits: Sequence['cirq.Qid'] + action: Any, args: 'cirq.ActOnDensityMatrixArgs', qubits: Sequence['cirq.Qid'] ) -> bool: """Apply channel to state.""" axes = args.get_axes(qubits) diff --git a/cirq-core/cirq/sim/act_on_state_vector_args.py b/cirq-core/cirq/sim/act_on_state_vector_args.py index a523047c71d..e0e71308e27 100644 --- a/cirq-core/cirq/sim/act_on_state_vector_args.py +++ b/cirq-core/cirq/sim/act_on_state_vector_args.py @@ -174,11 +174,13 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]: ) return bits - def _on_copy(self, target: 'ActOnStateVectorArgs'): + def _on_copy(self, target: 'cirq.ActOnStateVectorArgs'): target.target_tensor = self.target_tensor.copy() target.available_buffer = self.available_buffer.copy() - def _on_kronecker_product(self, other: 'ActOnStateVectorArgs', target: 'ActOnStateVectorArgs'): + def _on_kronecker_product( + self, other: 'cirq.ActOnStateVectorArgs', target: 'cirq.ActOnStateVectorArgs' + ): target_tensor = transformations.state_vector_kronecker_product( self.target_tensor, other.target_tensor ) @@ -188,8 +190,8 @@ def _on_kronecker_product(self, other: 'ActOnStateVectorArgs', target: 'ActOnSta def _on_factor( self, qubits: Sequence['cirq.Qid'], - extracted: 'ActOnStateVectorArgs', - remainder: 'ActOnStateVectorArgs', + extracted: 'cirq.ActOnStateVectorArgs', + remainder: 'cirq.ActOnStateVectorArgs', validate=True, atol=1e-07, ): @@ -203,7 +205,7 @@ def _on_factor( remainder.available_buffer = np.empty_like(remainder_tensor) def _on_transpose_to_qubit_order( - self, qubits: Sequence['cirq.Qid'], target: 'ActOnStateVectorArgs' + self, qubits: Sequence['cirq.Qid'], target: 'cirq.ActOnStateVectorArgs' ): axes = self.get_axes(qubits) new_tensor = transformations.transpose_state_vector_to_axis_order(self.target_tensor, axes) diff --git a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py index 4f5bcce2619..37396495a76 100644 --- a/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_clifford_tableau_args.py @@ -22,7 +22,6 @@ from cirq.ops import pauli_gates from cirq.ops.clifford_gate import SingleQubitCliffordGate from cirq.protocols import has_unitary, num_qubits, unitary -from cirq.qis.clifford_tableau import CliffordTableau from cirq.sim.act_on_args import ActOnArgs from cirq.type_workarounds import NotImplementedType @@ -39,7 +38,7 @@ class ActOnCliffordTableauArgs(ActOnArgs): def __init__( self, - tableau: CliffordTableau, + tableau: 'cirq.CliffordTableau', prng: np.random.RandomState, log_of_measurement_results: Dict[str, Any], qubits: Sequence['cirq.Qid'] = None, diff --git a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py index 62154568052..ce3d3e20710 100644 --- a/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py +++ b/cirq-core/cirq/sim/clifford/act_on_stabilizer_ch_form_args.py @@ -21,7 +21,6 @@ from cirq.ops.clifford_gate import SingleQubitCliffordGate from cirq.protocols import has_unitary, num_qubits, unitary from cirq.sim.act_on_args import ActOnArgs -from cirq.sim.clifford.stabilizer_state_ch_form import StabilizerStateChForm from cirq.type_workarounds import NotImplementedType if TYPE_CHECKING: @@ -38,7 +37,7 @@ class ActOnStabilizerCHFormArgs(ActOnArgs): def __init__( self, - state: StabilizerStateChForm, + state: 'cirq.StabilizerStateChForm', prng: np.random.RandomState, log_of_measurement_results: Dict[str, Any], qubits: Sequence['cirq.Qid'] = None, diff --git a/cirq-core/cirq/sim/clifford/clifford_simulator.py b/cirq-core/cirq/sim/clifford/clifford_simulator.py index 194373625a4..7d49f072e26 100644 --- a/cirq-core/cirq/sim/clifford/clifford_simulator.py +++ b/cirq-core/cirq/sim/clifford/clifford_simulator.py @@ -34,17 +34,17 @@ import numpy as np import cirq -from cirq import study, protocols, value +from cirq import protocols, value from cirq.protocols import act_on from cirq.sim import clifford, simulator_base class CliffordSimulator( simulator_base.SimulatorBase[ - 'CliffordSimulatorStepResult', - 'CliffordTrialResult', - 'CliffordState', - clifford.ActOnStabilizerCHFormArgs, + 'cirq.CliffordSimulatorStepResult', + 'cirq.CliffordTrialResult', + 'cirq.CliffordState', + 'cirq.ActOnStabilizerCHFormArgs', ], ): """An efficient simulator for Clifford circuits.""" @@ -66,10 +66,10 @@ def is_supported_operation(op: 'cirq.Operation') -> bool: def _create_partial_act_on_args( self, - initial_state: Union[int, clifford.ActOnStabilizerCHFormArgs], + initial_state: Union[int, 'cirq.ActOnStabilizerCHFormArgs'], qubits: Sequence['cirq.Qid'], logs: Dict[str, Any], - ) -> clifford.ActOnStabilizerCHFormArgs: + ) -> 'cirq.ActOnStabilizerCHFormArgs': """Creates the ActOnStabilizerChFormArgs for a circuit. Args: @@ -104,7 +104,7 @@ def _create_step_result( def _create_simulator_trial_result( self, - params: study.ParamResolver, + params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], final_step_result: 'CliffordSimulatorStepResult', ): @@ -121,9 +121,9 @@ class CliffordTrialResult( ): def __init__( self, - params: study.ParamResolver, + params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'CliffordSimulatorStepResult', + final_step_result: 'cirq.CliffordSimulatorStepResult', ) -> None: super().__init__( params=params, measurements=measurements, final_step_result=final_step_result @@ -144,7 +144,7 @@ def _repr_pretty_(self, p: Any, cycle: bool): class CliffordSimulatorStepResult( - simulator_base.StepResultBase['clifford.CliffordState', 'clifford.ActOnStabilizerCHFormArgs'] + simulator_base.StepResultBase['cirq.CliffordState', 'cirq.ActOnStabilizerCHFormArgs'] ): """A `StepResult` that includes `StateVectorMixin` methods.""" @@ -200,7 +200,7 @@ class CliffordState: Gates and measurements are applied to each representation in O(n^2) time. """ - def __init__(self, qubit_map, initial_state: Union[int, clifford.StabilizerStateChForm] = 0): + def __init__(self, qubit_map, initial_state: Union[int, 'cirq.StabilizerStateChForm'] = 0): self.qubit_map = qubit_map self.n = len(qubit_map) @@ -226,7 +226,7 @@ def _from_json_dict_(cls, qubit_map, ch_form, **kwargs): def _value_equality_values_(self) -> Any: return self.qubit_map, self.ch_form - def copy(self) -> 'CliffordState': + def copy(self) -> 'cirq.CliffordState': state = CliffordState(self.qubit_map) state.ch_form = self.ch_form.copy() diff --git a/cirq-core/cirq/sim/clifford/stabilizer_sampler.py b/cirq-core/cirq/sim/clifford/stabilizer_sampler.py index e2ef5ce514e..c6145d77b6e 100644 --- a/cirq-core/cirq/sim/clifford/stabilizer_sampler.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_sampler.py @@ -17,7 +17,7 @@ import numpy as np import cirq -from cirq import circuits, protocols, value +from cirq import protocols, value from cirq.qis.clifford_tableau import CliffordTableau from cirq.sim.clifford.act_on_clifford_tableau_args import ActOnCliffordTableauArgs from cirq.work import sampler @@ -51,7 +51,7 @@ def run_sweep( results.append(cirq.Result(params=param_resolver, measurements=measurements)) return results - def _run(self, circuit: circuits.AbstractCircuit, repetitions: int) -> Dict[str, np.ndarray]: + def _run(self, circuit: 'cirq.AbstractCircuit', repetitions: int) -> Dict[str, np.ndarray]: measurements: Dict[str, List[np.ndarray]] = { key: [] for key in protocols.measurement_key_names(circuit) diff --git a/cirq-core/cirq/sim/density_matrix_simulator.py b/cirq-core/cirq/sim/density_matrix_simulator.py index 2cecac2f76d..70b223e28a2 100644 --- a/cirq-core/cirq/sim/density_matrix_simulator.py +++ b/cirq-core/cirq/sim/density_matrix_simulator.py @@ -31,10 +31,10 @@ class DensityMatrixSimulator( simulator_base.SimulatorBase[ - 'DensityMatrixStepResult', - 'DensityMatrixTrialResult', - 'DensityMatrixSimulatorState', - act_on_density_matrix_args.ActOnDensityMatrixArgs, + 'cirq.DensityMatrixStepResult', + 'cirq.DensityMatrixTrialResult', + 'cirq.DensityMatrixSimulatorState', + 'cirq.ActOnDensityMatrixArgs', ], simulator.SimulatesExpectationValues, ): @@ -226,10 +226,10 @@ def _create_step_result( def _create_simulator_trial_result( self, - params: study.ParamResolver, + params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'DensityMatrixStepResult', - ) -> 'DensityMatrixTrialResult': + final_step_result: 'cirq.DensityMatrixStepResult', + ) -> 'cirq.DensityMatrixTrialResult': return DensityMatrixTrialResult( params=params, measurements=measurements, final_step_result=final_step_result ) @@ -239,8 +239,8 @@ def simulate_expectation_values_sweep( self, program: 'cirq.AbstractCircuit', observables: Union['cirq.PauliSumLike', List['cirq.PauliSumLike']], - params: 'study.Sweepable', - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + params: 'cirq.Sweepable', + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, initial_state: Any = None, permit_terminal_measurements: bool = False, ) -> List[List[float]]: @@ -270,9 +270,7 @@ def simulate_expectation_values_sweep( class DensityMatrixStepResult( - simulator_base.StepResultBase[ - 'DensityMatrixSimulatorState', act_on_density_matrix_args.ActOnDensityMatrixArgs - ] + simulator_base.StepResultBase['cirq.DensityMatrixSimulatorState', 'cirq.ActOnDensityMatrixArgs'] ): """A single step in the simulation of the DensityMatrixSimulator. @@ -284,7 +282,7 @@ class DensityMatrixStepResult( def __init__( self, sim_state: 'cirq.OperationTarget[cirq.ActOnDensityMatrixArgs]', - simulator: DensityMatrixSimulator = None, + simulator: 'cirq.DensityMatrixSimulator' = None, dtype: 'DTypeLike' = np.complex64, ): """DensityMatrixStepResult. @@ -300,7 +298,7 @@ def __init__( self._density_matrix: Optional[np.ndarray] = None self._simulator = simulator - def _simulator_state(self) -> 'DensityMatrixSimulatorState': + def _simulator_state(self) -> 'cirq.DensityMatrixSimulatorState': return DensityMatrixSimulatorState(self.density_matrix(copy=False), self._qubit_mapping) def set_density_matrix(self, density_matrix_repr: Union[int, np.ndarray]): @@ -380,7 +378,7 @@ class DensityMatrixSimulatorState: ordering of the basis in density_matrix. """ - def __init__(self, density_matrix: np.ndarray, qubit_map: Dict[ops.Qid, int]) -> None: + def __init__(self, density_matrix: np.ndarray, qubit_map: Dict['cirq.Qid', int]) -> None: self.density_matrix = density_matrix self.qubit_map = qubit_map self._qid_shape = simulator._qubit_map_to_shape(qubit_map) @@ -444,9 +442,9 @@ class DensityMatrixTrialResult( def __init__( self, - params: study.ParamResolver, + params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: DensityMatrixStepResult, + final_step_result: 'cirq.DensityMatrixStepResult', ) -> None: super().__init__( params=params, measurements=measurements, final_step_result=final_step_result diff --git a/cirq-core/cirq/sim/mux.py b/cirq-core/cirq/sim/mux.py index 8fbdc8ff64d..bce667d9d18 100644 --- a/cirq-core/cirq/sim/mux.py +++ b/cirq-core/cirq/sim/mux.py @@ -49,11 +49,11 @@ def sample( program: 'cirq.Circuit', *, noise: 'cirq.NOISE_MODEL_LIKE' = None, - param_resolver: Optional[study.ParamResolver] = None, + param_resolver: Optional['cirq.ParamResolver'] = None, repetitions: int = 1, dtype: Type[np.number] = np.complex64, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, -) -> study.Result: +) -> 'cirq.Result': """Simulates sampling from the given circuit. Args: @@ -103,8 +103,8 @@ def final_state_vector( program: 'cirq.CIRCUIT_LIKE', *, initial_state: 'cirq.STATE_VECTOR_LIKE' = 0, - param_resolver: study.ParamResolverOrSimilarType = None, - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + param_resolver: 'cirq.ParamResolverOrSimilarType' = None, + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, dtype: Type[np.number] = np.complex64, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> 'np.ndarray': @@ -163,13 +163,13 @@ def final_state_vector( def sample_sweep( program: 'cirq.Circuit', - params: study.Sweepable, + params: 'cirq.Sweepable', *, noise: 'cirq.NOISE_MODEL_LIKE' = None, repetitions: int = 1, dtype: Type[np.number] = np.complex64, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, -) -> List[study.Result]: +) -> List['cirq.Result']: """Runs the supplied Circuit, mimicking quantum hardware. In contrast to run, this allows for sweeping over different parameter @@ -211,8 +211,8 @@ def final_density_matrix( *, noise: 'cirq.NOISE_MODEL_LIKE' = None, initial_state: 'cirq.STATE_VECTOR_LIKE' = 0, - param_resolver: study.ParamResolverOrSimilarType = None, - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + param_resolver: 'cirq.ParamResolverOrSimilarType' = None, + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, dtype: Type[np.number] = np.complex64, seed: Optional[Union[int, np.random.RandomState]] = None, ignore_measurement_results: bool = True, diff --git a/cirq-core/cirq/sim/simulator.py b/cirq-core/cirq/sim/simulator.py index a9ac1d07759..4feade58888 100644 --- a/cirq-core/cirq/sim/simulator.py +++ b/cirq-core/cirq/sim/simulator.py @@ -71,17 +71,17 @@ class SimulatesSamples(work.Sampler, metaclass=abc.ABCMeta): def run_sweep( self, program: 'cirq.AbstractCircuit', - params: study.Sweepable, + params: 'cirq.Sweepable', repetitions: int = 1, - ) -> List[study.Result]: + ) -> List['cirq.Result']: return list(self.run_sweep_iter(program, params, repetitions)) def run_sweep_iter( self, program: 'cirq.AbstractCircuit', - params: study.Sweepable, + params: 'cirq.Sweepable', repetitions: int = 1, - ) -> Iterator[study.Result]: + ) -> Iterator['cirq.Result']: """Runs the supplied Circuit, mimicking quantum hardware. In contrast to run, this allows for sweeping over different parameter @@ -118,8 +118,8 @@ def run_sweep_iter( @abc.abstractmethod def _run( self, - circuit: circuits.AbstractCircuit, - param_resolver: study.ParamResolver, + circuit: 'cirq.AbstractCircuit', + param_resolver: 'cirq.ParamResolver', repetitions: int, ) -> Dict[str, np.ndarray]: """Run a simulation, mimicking quantum hardware. @@ -153,8 +153,8 @@ def compute_amplitudes( self, program: 'cirq.AbstractCircuit', bitstrings: Sequence[int], - param_resolver: 'study.ParamResolverOrSimilarType' = None, - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + param_resolver: 'cirq.ParamResolverOrSimilarType' = None, + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, ) -> Sequence[complex]: """Computes the desired amplitudes. @@ -182,8 +182,8 @@ def compute_amplitudes_sweep( self, program: 'cirq.AbstractCircuit', bitstrings: Sequence[int], - params: study.Sweepable, - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + params: 'cirq.Sweepable', + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, ) -> Sequence[Sequence[complex]]: """Wraps computed amplitudes in a list. @@ -195,8 +195,8 @@ def _compute_amplitudes_sweep_to_iter( self, program: 'cirq.AbstractCircuit', bitstrings: Sequence[int], - params: study.Sweepable, - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + params: 'cirq.Sweepable', + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, ) -> Iterator[Sequence[complex]]: if type(self).compute_amplitudes_sweep == SimulatesAmplitudes.compute_amplitudes_sweep: raise RecursionError( @@ -211,8 +211,8 @@ def compute_amplitudes_sweep_iter( self, program: 'cirq.AbstractCircuit', bitstrings: Sequence[int], - params: study.Sweepable, - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + params: 'cirq.Sweepable', + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, ) -> Iterator[Sequence[complex]]: """Computes the desired amplitudes. @@ -250,8 +250,8 @@ def simulate_expectation_values( self, program: 'cirq.AbstractCircuit', observables: Union['cirq.PauliSumLike', List['cirq.PauliSumLike']], - param_resolver: 'study.ParamResolverOrSimilarType' = None, - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + param_resolver: 'cirq.ParamResolverOrSimilarType' = None, + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, initial_state: Any = None, permit_terminal_measurements: bool = False, ) -> List[float]: @@ -298,8 +298,8 @@ def simulate_expectation_values_sweep( self, program: 'cirq.AbstractCircuit', observables: Union['cirq.PauliSumLike', List['cirq.PauliSumLike']], - params: 'study.Sweepable', - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + params: 'cirq.Sweepable', + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, initial_state: Any = None, permit_terminal_measurements: bool = False, ) -> List[List[float]]: @@ -322,8 +322,8 @@ def _simulate_expectation_values_sweep_to_iter( self, program: 'cirq.AbstractCircuit', observables: Union['cirq.PauliSumLike', List['cirq.PauliSumLike']], - params: 'study.Sweepable', - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + params: 'cirq.Sweepable', + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, initial_state: Any = None, permit_terminal_measurements: bool = False, ) -> Iterator[List[float]]: @@ -352,8 +352,8 @@ def simulate_expectation_values_sweep_iter( self, program: 'cirq.AbstractCircuit', observables: Union['cirq.PauliSumLike', List['cirq.PauliSumLike']], - params: 'study.Sweepable', - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + params: 'cirq.Sweepable', + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, initial_state: Any = None, permit_terminal_measurements: bool = False, ) -> Iterator[List[float]]: @@ -408,8 +408,8 @@ class SimulatesFinalState( def simulate( self, program: 'cirq.AbstractCircuit', - param_resolver: 'study.ParamResolverOrSimilarType' = None, - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + param_resolver: 'cirq.ParamResolverOrSimilarType' = None, + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, initial_state: Any = None, ) -> TSimulationTrialResult: """Simulates the supplied Circuit. @@ -437,8 +437,8 @@ def simulate( def simulate_sweep( self, program: 'cirq.AbstractCircuit', - params: study.Sweepable, - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + params: 'cirq.Sweepable', + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, initial_state: Any = None, ) -> List[TSimulationTrialResult]: """Wraps computed states in a list. @@ -450,8 +450,8 @@ def simulate_sweep( def _simulate_sweep_to_iter( self, program: 'cirq.AbstractCircuit', - params: study.Sweepable, - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + params: 'cirq.Sweepable', + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, initial_state: Any = None, ) -> Iterator[TSimulationTrialResult]: if type(self).simulate_sweep == SimulatesFinalState.simulate_sweep: @@ -462,8 +462,8 @@ def _simulate_sweep_to_iter( def simulate_sweep_iter( self, program: 'cirq.AbstractCircuit', - params: study.Sweepable, - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + params: 'cirq.Sweepable', + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, initial_state: Any = None, ) -> Iterator[TSimulationTrialResult]: """Simulates the supplied Circuit. @@ -510,8 +510,8 @@ class SimulatesIntermediateState( def simulate_sweep_iter( self, program: 'cirq.AbstractCircuit', - params: study.Sweepable, - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + params: 'cirq.Sweepable', + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, initial_state: Any = None, ) -> Iterator[TSimulationTrialResult]: """Simulates the supplied Circuit. @@ -557,9 +557,9 @@ def simulate_sweep_iter( def simulate_moment_steps( self, - circuit: circuits.AbstractCircuit, - param_resolver: 'study.ParamResolverOrSimilarType' = None, - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + circuit: 'cirq.AbstractCircuit', + param_resolver: 'cirq.ParamResolverOrSimilarType' = None, + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, initial_state: Any = None, ) -> Iterator[TStepResult]: """Returns an iterator of StepResults for each moment simulated. @@ -590,8 +590,8 @@ def simulate_moment_steps( def _base_iterator( self, - circuit: circuits.AbstractCircuit, - qubit_order: ops.QubitOrderOrList, + circuit: 'cirq.AbstractCircuit', + qubit_order: 'cirq.QubitOrderOrList', initial_state: Any, ) -> Iterator[TStepResult]: """Iterator over StepResult from Moments of a Circuit. @@ -643,7 +643,7 @@ def _create_act_on_args( @abc.abstractmethod def _core_iterator( self, - circuit: circuits.AbstractCircuit, + circuit: 'cirq.AbstractCircuit', sim_state: 'cirq.OperationTarget[TActOnArgs]', all_measurements_are_terminal: bool = False, ) -> Iterator[TStepResult]: @@ -666,7 +666,7 @@ def _core_iterator( @abc.abstractmethod def _create_simulator_trial_result( self, - params: study.ParamResolver, + params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], final_step_result: TStepResult, ) -> TSimulationTrialResult: @@ -709,7 +709,7 @@ def _simulator_state(self) -> TSimulatorState: @abc.abstractmethod def sample( self, - qubits: List[ops.Qid], + qubits: List['cirq.Qid'], repetitions: int = 1, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> np.ndarray: @@ -733,7 +733,7 @@ def sample( def sample_measurement_ops( self, - measurement_ops: List[ops.GateOperation], + measurement_ops: List['cirq.GateOperation'], repetitions: int = 1, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> Dict[str, np.ndarray]: @@ -823,10 +823,10 @@ class SimulationTrialResult: def __init__( self, - params: study.ParamResolver, + params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], final_simulator_state: Any = None, - final_step_result: StepResult = None, + final_step_result: 'cirq.StepResult' = None, ) -> None: """Initializes the `SimulationTrialResult` class. @@ -892,7 +892,7 @@ def _value_equality_values_(self) -> Any: return self.params, measurements, self._final_simulator_state @property - def qubit_map(self) -> Dict[ops.Qid, int]: + def qubit_map(self) -> Dict['cirq.Qid', int]: """A map from Qid to index used to define the ordering of the basis in the result. """ @@ -902,7 +902,7 @@ def _qid_shape_(self) -> Tuple[int, ...]: return _qubit_map_to_shape(self.qubit_map) -def _qubit_map_to_shape(qubit_map: Dict[ops.Qid, int]) -> Tuple[int, ...]: +def _qubit_map_to_shape(qubit_map: Dict['cirq.Qid', int]) -> Tuple[int, ...]: qid_shape: List[int] = [-1] * len(qubit_map) try: for q, i in qubit_map.items(): @@ -914,7 +914,7 @@ def _qubit_map_to_shape(qubit_map: Dict[ops.Qid, int]) -> Tuple[int, ...]: return tuple(qid_shape) -def _verify_unique_measurement_keys(circuit: circuits.AbstractCircuit): +def _verify_unique_measurement_keys(circuit: 'cirq.AbstractCircuit'): result = collections.Counter( key for op in ops.flatten_op_tree(iter(circuit)) diff --git a/cirq-core/cirq/sim/simulator_base.py b/cirq-core/cirq/sim/simulator_base.py index 201eeca70b9..1dbd1e8d43e 100644 --- a/cirq-core/cirq/sim/simulator_base.py +++ b/cirq-core/cirq/sim/simulator_base.py @@ -33,7 +33,7 @@ import numpy as np -from cirq import circuits, ops, protocols, study, value, devices +from cirq import ops, protocols, study, value, devices from cirq.sim import ActOnArgsContainer from cirq.sim.operation_target import OperationTarget from cirq.sim.simulator import ( @@ -177,7 +177,7 @@ def _can_be_in_run_prefix(self, val: Any): def _core_iterator( self, - circuit: circuits.AbstractCircuit, + circuit: 'cirq.AbstractCircuit', sim_state: OperationTarget[TActOnArgs], all_measurements_are_terminal: bool = False, ) -> Iterator[TStepResultBase]: @@ -231,8 +231,8 @@ def _core_iterator( def _run( self, - circuit: circuits.AbstractCircuit, - param_resolver: study.ParamResolver, + circuit: 'cirq.AbstractCircuit', + param_resolver: 'cirq.ParamResolver', repetitions: int, ) -> Dict[str, np.ndarray]: """See definition in `cirq.SimulatesSamples`.""" @@ -286,8 +286,8 @@ def _run( def simulate_sweep_iter( self, program: 'cirq.AbstractCircuit', - params: study.Sweepable, - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + params: 'cirq.Sweepable', + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, initial_state: Any = None, ) -> Iterator[TSimulationTrialResult]: """Simulates the supplied Circuit. @@ -400,7 +400,7 @@ def _merged_sim_state(self): def sample( self, - qubits: List[ops.Qid], + qubits: List['cirq.Qid'], repetitions: int = 1, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None, ) -> np.ndarray: diff --git a/cirq-core/cirq/sim/sparse_simulator.py b/cirq-core/cirq/sim/sparse_simulator.py index 7c4fa5edea6..4a1fd5da9f0 100644 --- a/cirq-core/cirq/sim/sparse_simulator.py +++ b/cirq-core/cirq/sim/sparse_simulator.py @@ -220,7 +220,7 @@ def simulate_expectation_values_sweep_iter( program: 'cirq.AbstractCircuit', observables: Union['cirq.PauliSumLike', List['cirq.PauliSumLike']], params: 'cirq.Sweepable', - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, initial_state: Any = None, permit_terminal_measurements: bool = False, ) -> Iterator[List[float]]: @@ -252,7 +252,7 @@ class SparseSimulatorStep( def __init__( self, sim_state: 'cirq.OperationTarget[cirq.ActOnStateVectorArgs]', - simulator: Simulator = None, + simulator: 'cirq.Simulator' = None, dtype: 'DTypeLike' = np.complex64, ): """Results of a step of the simulator. @@ -269,7 +269,7 @@ def __init__( self._state_vector: Optional[np.ndarray] = None self._simulator = simulator - def _simulator_state(self) -> state_vector_simulator.StateVectorSimulatorState: + def _simulator_state(self) -> 'cirq.StateVectorSimulatorState': return state_vector_simulator.StateVectorSimulatorState( qubit_map=self.qubit_map, state_vector=self.state_vector(copy=False) ) diff --git a/cirq-core/cirq/sim/state_vector.py b/cirq-core/cirq/sim/state_vector.py index a6150194cbe..d448caca883 100644 --- a/cirq-core/cirq/sim/state_vector.py +++ b/cirq-core/cirq/sim/state_vector.py @@ -13,12 +13,12 @@ # limitations under the License. """Helpers for handling quantum state vectors.""" +import abc from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Sequence -import abc import numpy as np -from cirq import linalg, ops, qis, value +from cirq import linalg, qis, value from cirq.sim import simulator if TYPE_CHECKING: @@ -26,13 +26,12 @@ # For backwards compatibility and to make mypy happy: -from cirq.qis import STATE_VECTOR_LIKE # pylint: disable=unused-import,wrong-import-position class StateVectorMixin: """A mixin that provide methods for objects that have a state vector.""" - def __init__(self, qubit_map: Optional[Dict[ops.Qid, int]] = None, *args, **kwargs): + def __init__(self, qubit_map: Optional[Dict['cirq.Qid', int]] = None, *args, **kwargs): """Inits StateVectorMixin. Args: @@ -49,7 +48,7 @@ def __init__(self, qubit_map: Optional[Dict[ops.Qid, int]] = None, *args, **kwar self._qid_shape = None if qubit_map is None else qid_shape @property - def qubit_map(self) -> Dict[ops.Qid, int]: + def qubit_map(self) -> Dict['cirq.Qid', int]: return self._qubit_map def _qid_shape_(self) -> Tuple[int, ...]: @@ -98,7 +97,7 @@ def dirac_notation(self, decimals: int = 2) -> str: and non-zero floats of the specified accuracy.""" return qis.dirac_notation(self.state_vector(), decimals, qid_shape=self._qid_shape) - def density_matrix_of(self, qubits: List[ops.Qid] = None) -> np.ndarray: + def density_matrix_of(self, qubits: List['cirq.Qid'] = None) -> np.ndarray: r"""Returns the density matrix of the state. Calculate the density matrix for the system on the list, qubits. diff --git a/cirq-core/cirq/sim/state_vector_simulator.py b/cirq-core/cirq/sim/state_vector_simulator.py index d0fcb64c8f8..3e92bc2a31c 100644 --- a/cirq-core/cirq/sim/state_vector_simulator.py +++ b/cirq-core/cirq/sim/state_vector_simulator.py @@ -29,10 +29,9 @@ import numpy as np -from cirq import ops, study, value, qis +from cirq import ops, value, qis from cirq._compat import proper_repr from cirq.sim import simulator, state_vector, simulator_base -from cirq.sim.act_on_state_vector_args import ActOnStateVectorArgs if TYPE_CHECKING: import cirq @@ -45,9 +44,9 @@ class SimulatesIntermediateStateVector( Generic[TStateVectorStepResult], simulator_base.SimulatorBase[ TStateVectorStepResult, - 'StateVectorTrialResult', - 'StateVectorSimulatorState', - ActOnStateVectorArgs, + 'cirq.StateVectorTrialResult', + 'cirq.StateVectorSimulatorState', + 'cirq.ActOnStateVectorArgs', ], simulator.SimulatesAmplitudes, metaclass=abc.ABCMeta, @@ -74,10 +73,10 @@ def __init__( def _create_simulator_trial_result( self, - params: study.ParamResolver, + params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: 'StateVectorStepResult', - ) -> 'StateVectorTrialResult': + final_step_result: 'cirq.StateVectorStepResult', + ) -> 'cirq.StateVectorTrialResult': return StateVectorTrialResult( params=params, measurements=measurements, final_step_result=final_step_result ) @@ -86,8 +85,8 @@ def compute_amplitudes_sweep_iter( self, program: 'cirq.AbstractCircuit', bitstrings: Sequence[int], - params: study.Sweepable, - qubit_order: ops.QubitOrderOrList = ops.QubitOrder.DEFAULT, + params: 'cirq.Sweepable', + qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT, ) -> Iterator[Sequence[complex]]: if isinstance(bitstrings, np.ndarray) and len(bitstrings.shape) > 1: raise ValueError( @@ -159,9 +158,9 @@ class StateVectorTrialResult( def __init__( self, - params: study.ParamResolver, + params: 'cirq.ParamResolver', measurements: Dict[str, np.ndarray], - final_step_result: StateVectorStepResult, + final_step_result: 'cirq.StateVectorStepResult', ) -> None: super().__init__( params=params,